mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 23:56:47 -05:00
perf: unify LLM interface (#518)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -11,6 +11,7 @@ import inflection
|
||||
import orjson
|
||||
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import ClickException
|
||||
from click import shell_completion as sc
|
||||
from click.shell_completion import CompletionItem
|
||||
|
||||
@@ -28,6 +29,9 @@ from openllm_core._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import get_literal_args
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import check_bool_env
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
from . import termui
|
||||
|
||||
@@ -62,7 +66,6 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
|
||||
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
@@ -84,7 +87,8 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
|
||||
return None
|
||||
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
@@ -117,24 +121,23 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend,
|
||||
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
if _serialisation == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
|
||||
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
|
||||
if _serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
termui.echo(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
|
||||
fg='yellow')
|
||||
termui.echo(f"Make sure to check out '{model_id}' repository to see if the weights is in '{_serialisation}' format if unsure.")
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
adapter_map: dict[str, str] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
|
||||
server_timeout = first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
|
||||
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
|
||||
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop('development')
|
||||
server_attrs.setdefault('production', not development)
|
||||
wpr = openllm.utils.first_not_none(workers_per_resource, default=config['workers_per_resource'])
|
||||
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
|
||||
|
||||
if isinstance(wpr, str):
|
||||
if wpr == 'round_robin': wpr = 1.0
|
||||
@@ -151,7 +154,10 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
wpr = float(wpr)
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = openllm.utils.EnvVarMixin(config['model_name'], backend, model_id=model_id or config['default_id'], quantize=quantize)
|
||||
env = openllm.utils.EnvVarMixin(config['model_name'],
|
||||
backend=openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
|
||||
model_id=model_id or config['default_id'],
|
||||
quantize=quantize)
|
||||
requirements = llm_config['requirements']
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
@@ -176,16 +182,16 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
|
||||
model_id=start_env[env.model_id],
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
ensure_available=True,
|
||||
adapter_map=adapter_map,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id=start_env[env.model_id],
|
||||
revision=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
backend=env['backend_value'],
|
||||
adapter_map=adapter_map,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
llm.save_pretrained() # ensure_available = True
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode()})
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
|
||||
@@ -382,8 +388,8 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
return cli_option('--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:-2]),
|
||||
default='pt',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
@@ -396,7 +402,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
return cli_option('--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(['int8', 'int4', 'gptq']),
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
|
||||
Reference in New Issue
Block a user