Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-14 10:27:48 -05:00)
perf(build): lock packages and improve build speed (#669)
* revert(build): not locking packages
* perf: improve svars generation and unifying envvar parsing
* docs: update changelog
* chore: update stubs check for mypy

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
openllm-python/src/openllm/_service.py
@@ -1,11 +1,8 @@
 # mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
 from __future__ import annotations
 import logging
-import os
 import typing as t

 import _service_vars as svars
-import orjson
-
 import bentoml
 import openllm
@@ -14,18 +11,17 @@ from bentoml.io import JSON, Text
 logger = logging.getLogger(__name__)

 llm = openllm.LLM[t.Any, t.Any](
-  svars.model_id,
+  model_id=svars.model_id,
   model_tag=svars.model_tag,
-  prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
-  system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
-  serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
-  adapter_map=orjson.loads(svars.adapter_map),
-  trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False),
+  prompt_template=svars.prompt_template,
+  system_message=svars.system_message,
+  serialisation=svars.serialization,
+  adapter_map=svars.adapter_map,
+  trust_remote_code=svars.trust_remote_code,
 )
-llm_config = llm.config
-svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
+svc = bentoml.Service(name=f"llm-{llm.config['start_name']}-service", runners=[llm.runner])

-llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
+llm_model_class = openllm.GenerationInput.from_llm_config(llm.config)


 @svc.api(
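After this hunk, `_service.py` no longer parses environment variables itself; every knob arrives through the `_service_vars` module (see the `_service_vars.py` hunk below). A minimal stand-in for that module, with hypothetical values, is enough to exercise the new contract without running `openllm build` — a sketch, not part of the change:

    # Illustrative stand-in for the generated _service_vars module; values are hypothetical.
    import sys, types

    svars = types.ModuleType('_service_vars')
    svars.model_id = 'facebook/opt-125m'   # any model id openllm supports
    svars.model_tag = None
    svars.adapter_map = None
    svars.prompt_template = None
    svars.system_message = None
    svars.serialization = 'safetensors'
    svars.trust_remote_code = False
    sys.modules['_service_vars'] = svars   # must be registered before 'import _service'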
@@ -49,11 +45,11 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:


 _Metadata = openllm.MetadataOutput(
-  timeout=llm_config['timeout'],
-  model_name=llm_config['model_name'],
+  timeout=llm.config['timeout'],
+  model_name=llm.config['model_name'],
   backend=llm.__llm_backend__,
   model_id=llm.model_id,
-  configuration=llm_config.model_dump_json().decode(),
+  configuration=llm.config.model_dump_json().decode(),
   prompt_template=llm.runner.prompt_template,
   system_message=llm.runner.system_message,
 )
openllm-python/src/openllm/_service_vars.py
@@ -1,5 +1,13 @@
 import os

-model_id = os.environ['OPENLLM_MODEL_ID'] # openllm: model name
-model_tag = None # openllm: model tag
-adapter_map = os.environ['OPENLLM_ADAPTER_MAP'] # openllm: model adapter map
+import orjson
+
+from openllm_core.utils import ENV_VARS_TRUE_VALUES
+
+model_id = os.environ['OPENLLM_MODEL_ID']
+model_tag = None
+adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
+prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
+system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
+serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
+trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES
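The `trust_remote_code` line is the usual Hugging Face-style boolean env parsing — normalize the string, then test membership in a truthy set — and mirrors what `openllm.utils.check_bool_env` (called in the old `_service.py` above) does. In isolation the pattern looks like this (a sketch; the exact contents of `ENV_VARS_TRUE_VALUES` are assumed to be the conventional set):

    import os

    ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}  # assumed to match openllm_core.utils

    def check_bool_env(name, default=False):
      # 'true', 'Yes', '1', 'ON', ... all normalize to a member of the truthy set
      return str(os.getenv(name, default=str(default))).upper() in ENV_VARS_TRUE_VALUES

    os.environ['TRUST_REMOTE_CODE'] = 'yes'
    assert check_bool_env('TRUST_REMOTE_CODE') is True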
openllm-python/src/openllm/_service_vars.pyi (new file, 11 lines)
@@ -0,0 +1,11 @@
+from typing import Dict, Optional
+
+from openllm_core._typing_compat import LiteralSerialisation
+
+model_id: str = ...
+model_tag: Optional[str] = ...
+adapter_map: Optional[Dict[str, str]] = ...
+prompt_template: Optional[str] = ...
+system_message: Optional[str] = ...
+serialization: LiteralSerialisation = ...
+trust_remote_code: bool = ...
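Because `_service_vars.py` is generated at build time (or env-driven in development), this checked-in stub is what gives mypy concrete types for the module — which is presumably what the "update stubs check for mypy" sub-commit refers to. An illustrative consequence (the second line is a deliberate type error shown for the checker, not runnable code):

    import _service_vars as svars

    svars.trust_remote_code    # mypy sees: bool
    svars.model_tag.upper()    # mypy error: "None" has no attribute "upper" (model_tag is Optional[str])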
openllm-python/src/openllm/_service_vars_pkg.py
@@ -1,3 +1,9 @@
+import orjson
+
 model_id = '{__model_id__}' # openllm: model id
 model_tag = '{__model_tag__}' # openllm: model tag
-adapter_map = """{__model_adapter_map__}""" # openllm: model adapter map
+adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
+serialization = '{__model_serialization__}' # openllm: model serialization
+prompt_template = {__model_prompt_template__} # openllm: model prompt template
+system_message = {__model_system_message__} # openllm: model system message
+trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code
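Once `write_service` (further down) substitutes the `{__...__}` placeholders, the `_service_vars.py` shipped inside a Bento is fully static — no env parsing at container start. A sketch of what a rendered file might look like, with purely illustrative values:

    # GENERATED BY 'openllm build facebook/opt-125m'. DO NOT EDIT
    import orjson

    model_id = 'facebook/opt-125m' # openllm: model id
    model_tag = 'vllm-facebook--opt-125m:latest' # openllm: model tag (illustrative)
    adapter_map = orjson.loads("""null""") # openllm: model adapter map
    serialization = 'safetensors' # openllm: model serialization
    prompt_template = None # openllm: model prompt template
    system_message = None # openllm: model system message
    trust_remote_code = False # openllm: model trust remote code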
openllm-python/src/openllm/bundle/_package.py
@@ -54,28 +54,17 @@ def build_editable(path, package='openllm'):
-

 def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None):
-  packages = ['openllm', 'scipy'] # apparently bnb misses this one
+  packages = ['scipy', 'bentoml[tracing]==1.1.9'] # apparently bnb misses this one
   if adapter_map is not None:
     packages += ['openllm[fine-tune]']
-  # NOTE: add openllm to the default dependencies
-  # if users has openllm custom built wheels, it will still respect
-  # that since bentoml will always install dependencies from requirements.txt
-  # first, then proceed to install everything inside the wheels/ folder.
   if extra_dependencies is not None:
     packages += [f'openllm[{k}]' for k in extra_dependencies]

-  req = llm.config['requirements']
-  if req is not None:
-    packages.extend(req)
-  if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false':
-    packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
-
+  # XXX: Currently locking this for correctness
+  packages.extend(['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9'])
-  wheels = []
-  if llm.config['requirements'] is not None:
-    packages.extend(llm.config['requirements'])
+  wheels = None
   built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
   if all(i for i in built_wheels):
-    wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
+    wheels = [llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels]
   return PythonOptions(
     packages=packages,
     wheels=wheels,
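The net effect: instead of asking pip to resolve a floating `bentoml` version at image-build time, the Bento now carries an explicitly pinned set. Roughly, the options handed to BentoML end up like the sketch below (the `PythonOptions` import path is an assumption based on how `_package.py` uses it; values mirror the pins in the diff):

    from bentoml._internal.bento.build_config import PythonOptions  # assumed import path

    opts = PythonOptions(
      packages=[
        'scipy',
        'bentoml[tracing]==1.1.9',
        # pins "locked for correctness" per the XXX comment above
        'torch==2.0.1+cu118',
        'vllm==0.2.1.post1',
        'xformers==0.0.22',
      ],
      wheels=None,  # or the freshly built openllm_core/openllm_client/openllm editable wheels
    )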
@@ -90,30 +79,25 @@ def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None):
 def construct_docker_options(
   llm, _, quantize, adapter_map, dockerfile_template, serialisation, container_registry, container_version_strategy
 ):
-  from openllm_cli._factory import parse_config_options
+  from openllm_cli.entrypoint import process_environ

-  environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
-  env_dict = {
-    'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
-    'OPENLLM_BACKEND': llm.__llm_backend__,
-    'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
-    'OPENLLM_SERIALIZATION': serialisation,
-    'BENTOML_DEBUG': str(True),
-    'BENTOML_QUIET': str(False),
-    'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
-    'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
-  }
-  if adapter_map:
-    env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
-  if llm._system_message:
-    env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
-  if llm._prompt_template:
-    env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
-  if quantize:
-    env_dict['OPENLLM_QUANTIZE'] = str(quantize)
+  environ = process_environ(
+    llm.config,
+    llm.config['timeout'],
+    1.0,
+    None,
+    True,
+    llm.model_id,
+    None,
+    llm._serialisation,
+    llm,
+    llm._system_message,
+    llm._prompt_template,
+    use_current_env=False,
+  )
   return DockerOptions(
     base_image=oci.RefResolver.construct_base_image(container_registry, container_version_strategy),
-    env=env_dict,
+    env=environ,
     dockerfile_template=dockerfile_template,
   )
@@ -121,6 +105,10 @@ def construct_docker_options(
 OPENLLM_MODEL_ID = '# openllm: model id'
 OPENLLM_MODEL_TAG = '# openllm: model tag'
 OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
+OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
+OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
+OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
+OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'


 class _ServiceVarsFormatter(string.Formatter):
@@ -156,6 +144,26 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
   identifier = OPENLLM_MODEL_ADAPTER_MAP


+class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
+  keyword = '__model_prompt_template__'
+  identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
+
+
+class ModelSystemMessageFormatter(_ServiceVarsFormatter):
+  keyword = '__model_system_message__'
+  identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
+
+
+class ModelSerializationFormatter(_ServiceVarsFormatter):
+  keyword = '__model_serialization__'
+  identifier = OPENLLM_MODEL_SERIALIZATION
+
+
+class ModelTrustRemoteCodeFormatter(_ServiceVarsFormatter):
+  keyword = '__model_trust_remote_code__'
+  identifier = OPENLLM_MODEL_TRUST_REMOTE_CODE
+
+
 _service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
+_service_vars_file = Path(os.path.abspath(__file__)).parent.parent / '_service_vars_pkg.py'
@@ -164,6 +172,8 @@ def write_service(llm, llm_fs, adapter_map):
   model_id_formatter = ModelIdFormatter(llm.model_id)
   model_tag_formatter = ModelTagFormatter(str(llm.tag))
   adapter_map_formatter = ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode())
+  serialization_formatter = ModelSerializationFormatter(llm.config['serialisation'])
+  trust_remote_code_formatter = ModelTrustRemoteCodeFormatter(str(llm.trust_remote_code))

   logger.debug(
     'Generating service vars file for %s at %s (dir=%s)', llm.model_id, '_service_vars.py', llm_fs.getsyspath('/')
@@ -177,6 +187,20 @@ def write_service(llm, llm_fs, adapter_map):
       src_contents[i] = model_tag_formatter.parse_line(it)
     elif adapter_map_formatter.identifier in it:
       src_contents[i] = adapter_map_formatter.parse_line(it)
+    elif serialization_formatter.identifier in it:
+      src_contents[i] = serialization_formatter.parse_line(it)
+    elif trust_remote_code_formatter.identifier in it:
+      src_contents[i] = trust_remote_code_formatter.parse_line(it)
+    elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
+      if llm._prompt_template:
+        src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
+      else:
+        src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
+    elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
+      if llm._system_message:
+        src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
+      else:
+        src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)

   script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
   if SHOW_CODEGEN:
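The formatters key off the trailing `# openllm: ...` marker comments: each one rewrites only the template line carrying its identifier. The substitution itself is ordinary `string.Formatter`-style field replacement, roughly (simplified sketch, hypothetical value):

    line = "model_id = '{__model_id__}' # openllm: model id\n"
    rendered = line.format(__model_id__='facebook/opt-125m')  # hypothetical model id
    # rendered == "model_id = 'facebook/opt-125m' # openllm: model id\n"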
openllm-python/src/openllm_cli/entrypoint.py
@@ -622,9 +622,22 @@ def start_grpc_command(


 def process_environ(
-  config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, system_message, prompt_template
+  config,
+  server_timeout,
+  wpr,
+  device,
+  cors,
+  model_id,
+  adapter_map,
+  serialisation,
+  llm,
+  system_message,
+  prompt_template,
+  use_current_env=True,
 ) -> t.Dict[str, t.Any]:
-  environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
+  environ = parse_config_options(
+    config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
+  )
   environ.update(
     {
       'OPENLLM_MODEL_ID': model_id,
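`construct_docker_options` above calls this with `use_current_env=False`, so the container env is derived only from the model's own settings rather than from whatever happens to be in the build machine's shell. For readability, the same call with each positional argument labeled against this signature (an annotated restatement of the diff, not new code; `wpr` is presumably workers-per-resource):

    environ = process_environ(
      llm.config,             # config
      llm.config['timeout'],  # server_timeout
      1.0,                    # wpr
      None,                   # device
      True,                   # cors
      llm.model_id,           # model_id
      None,                   # adapter_map
      llm._serialisation,     # serialisation
      llm,                    # llm
      llm._system_message,    # system_message
      llm._prompt_template,   # prompt_template
      use_current_env=False,  # start from {} instead of os.environ.copy()
    )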
@@ -1019,22 +1032,21 @@ def build_command(
     ),
   )
   backend_warning(llm.__llm_backend__, build=True)

   os.environ.update(
-    {
-      'TORCH_DTYPE': dtype,
-      'OPENLLM_BACKEND': llm.__llm_backend__,
-      'OPENLLM_SERIALIZATION': llm._serialisation,
-      'OPENLLM_MODEL_ID': llm.model_id,
-      'OPENLLM_ADAPTER_MAP': orjson.dumps(None).decode(),
-    }
+    **process_environ(
+      llm.config,
+      llm.config['timeout'],
+      1.0,
+      None,
+      True,
+      llm.model_id,
+      None,
+      llm._serialisation,
+      llm,
+      llm._system_message,
+      llm._prompt_template,
+    )
   )
-  if llm.quantise:
-    os.environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
-  if system_message:
-    os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
-  if prompt_template:
-    os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
-
   try:
     assert llm.bentomodel  # HACK: call it here to patch correct tag with revision and everything
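With this, both the CLI start path and `openllm build` funnel through `process_environ`, so the env contract is defined in one place. One small Python note on the call shape: `os.environ.update(**mapping)` expands a string-keyed dict as keyword arguments, which is equivalent to `os.environ.update(mapping)`:

    import os

    mapping = {'OPENLLM_BACKEND': 'pt', 'OPENLLM_SERIALIZATION': 'safetensors'}  # illustrative values
    os.environ.update(**mapping)  # same effect as os.environ.update(mapping)
    assert os.environ['OPENLLM_BACKEND'] == 'pt'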
@@ -1049,7 +1061,7 @@ def build_command(
     llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
     dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')

-  adapter_map: dict[str, str] | None = None
+  adapter_map = None
   if adapter_id and not build_ctx:
     ctx.fail("'build_ctx' is required when '--adapter-id' is passsed.")
   if adapter_id: