perf: unify LLM interface (#518)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-06 20:39:43 -05:00
committed by GitHub
parent f2639879af
commit e2029c934b
136 changed files with 9646 additions and 11244 deletions

View File

@@ -11,6 +11,7 @@ import inflection
import orjson
from bentoml_cli.utils import BentoMLCommandGroup
from click import ClickException
from click import shell_completion as sc
from click.shell_completion import CompletionItem
@@ -28,6 +29,9 @@ from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import get_literal_args
from openllm_core.utils import DEBUG
from openllm_core.utils import check_bool_env
from openllm_core.utils import first_not_none
from openllm_core.utils import is_vllm_available
from . import termui
@@ -62,7 +66,6 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
else:
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
if cors:
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
@@ -84,7 +87,8 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
return None
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
@@ -117,24 +121,23 @@ Available official model_id(s): [default: {llm_config['default_id']}]
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
if _serialisation == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
if _serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.echo(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
fg='yellow')
termui.echo(f"Make sure to check out '{model_id}' repository to see if the weights is in '{_serialisation}' format if unsure.")
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
adapter_map: dict[str, str] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = openllm.utils.first_not_none(workers_per_resource, default=config['workers_per_resource'])
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin': wpr = 1.0
@@ -151,7 +154,10 @@ Available official model_id(s): [default: {llm_config['default_id']}]
wpr = float(wpr)
# Create a new model env to work with the envvar during CLI invocation
env = openllm.utils.EnvVarMixin(config['model_name'], backend, model_id=model_id or config['default_id'], quantize=quantize)
env = openllm.utils.EnvVarMixin(config['model_name'],
backend=openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
model_id=model_id or config['default_id'],
quantize=quantize)
requirements = llm_config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
@@ -176,16 +182,16 @@ Available official model_id(s): [default: {llm_config['default_id']}]
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
model_id=start_env[env.model_id],
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
ensure_available=True,
adapter_map=adapter_map,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm = openllm.LLM[t.Any, t.Any](model_id=start_env[env.model_id],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
backend=env['backend_value'],
adapter_map=adapter_map,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm.save_pretrained() # ensure_available = True
start_env.update({env.config: llm.config.model_dump_json().decode()})
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
@@ -382,8 +388,8 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
# XXX: remove the check for __args__ once we have ggml and mlc supports
return cli_option('--backend',
type=click.Choice(get_literal_args(LiteralBackend)[:-2]),
default='pt',
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
default=None,
envvar='OPENLLM_BACKEND',
show_envvar=True,
help='The implementation for saving this LLM.',
@@ -396,7 +402,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
return cli_option('--quantise',
'--quantize',
'quantize',
type=click.Choice(['int8', 'int4', 'gptq']),
type=click.Choice(get_literal_args(LiteralQuantise)),
default=None,
envvar='OPENLLM_QUANTIZE',
show_envvar=True,