perf(build): revert package locking and improve build speed (#669)

* revert(build): do not lock packages

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* perf: improve svars generation and unify envvar parsing

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* docs: update changelog

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update stubs check for mypy

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Authored by Aaron Pham on 2023-11-16 06:27:45 -05:00, committed by GitHub
parent fce8f223f3, commit 8fdfd0491f
9 changed files with 138 additions and 74 deletions

View File

@@ -1,11 +1,8 @@
# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
from __future__ import annotations
import logging
import os
import typing as t
import _service_vars as svars
import orjson
import bentoml
import openllm
@@ -14,18 +11,17 @@ from bentoml.io import JSON, Text
logger = logging.getLogger(__name__)
llm = openllm.LLM[t.Any, t.Any](
svars.model_id,
model_id=svars.model_id,
model_tag=svars.model_tag,
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
adapter_map=orjson.loads(svars.adapter_map),
trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False),
prompt_template=svars.prompt_template,
system_message=svars.system_message,
serialisation=svars.serialization,
adapter_map=svars.adapter_map,
trust_remote_code=svars.trust_remote_code,
)
llm_config = llm.config
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
svc = bentoml.Service(name=f"llm-{llm.config['start_name']}-service", runners=[llm.runner])
llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
llm_model_class = openllm.GenerationInput.from_llm_config(llm.config)
@svc.api(
@@ -49,11 +45,11 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
_Metadata = openllm.MetadataOutput(
timeout=llm_config['timeout'],
model_name=llm_config['model_name'],
timeout=llm.config['timeout'],
model_name=llm.config['model_name'],
backend=llm.__llm_backend__,
model_id=llm.model_id,
configuration=llm_config.model_dump_json().decode(),
configuration=llm.config.model_dump_json().decode(),
prompt_template=llm.runner.prompt_template,
system_message=llm.runner.system_message,
)

View File

@@ -1,5 +1,13 @@
import os
model_id = os.environ['OPENLLM_MODEL_ID'] # openllm: model name
model_tag = None # openllm: model tag
adapter_map = os.environ['OPENLLM_ADAPTER_MAP'] # openllm: model adapter map
import orjson
from openllm_core.utils import ENV_VARS_TRUE_VALUES
model_id = os.environ['OPENLLM_MODEL_ID']
model_tag = None
adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES
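For reference, a minimal sketch of how this boolean parsing behaves. The exact members of ENV_VARS_TRUE_VALUES are an assumption here (the usual truthy spellings), so treat the set literal as illustrative rather than the actual openllm_core definition:

import os

ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}  # assumption: approximates openllm_core.utils.ENV_VARS_TRUE_VALUES

os.environ['TRUST_REMOTE_CODE'] = 'yes'
# Same expression as in the generated module above: case-insensitive match against the truthy set.
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES
print(trust_remote_code)  # True; unset, '0', or 'false' would yield False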

View File

@@ -0,0 +1,11 @@
from typing import Dict, Optional
from openllm_core._typing_compat import LiteralSerialisation
model_id: str = ...
model_tag: Optional[str] = ...
adapter_map: Optional[Dict[str, str]] = ...
prompt_template: Optional[str] = ...
system_message: Optional[str] = ...
serialization: LiteralSerialisation = ...
trust_remote_code: bool = ...
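Since _service_vars.py is generated at build time, the stub above is what gives mypy (and editors) concrete types for it. A small illustration of what a checker can now infer; reveal_type is a mypy-only form, not runtime code, and the revealed types shown in comments are approximate:

import _service_vars as svars  # resolved against the .pyi stub during type checking

reveal_type(svars.adapter_map)        # roughly: Optional[Dict[str, str]]
reveal_type(svars.trust_remote_code)  # roughly: bool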

View File

@@ -1,3 +1,9 @@
import orjson
model_id = '{__model_id__}' # openllm: model id
model_tag = '{__model_tag__}' # openllm: model tag
adapter_map = """{__model_adapter_map__}""" # openllm: model adapter map
adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
serialization = '{__model_serialization__}' # openllm: model serialization
prompt_template = {__model_prompt_template__} # openllm: model prompt template
system_message = {__model_system_message__} # openllm: model system message
trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code
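For illustration, after write_service substitutes the placeholders, the generated _service_vars.py might read roughly as follows; the model id, tag, and values are invented, and a build with no adapters and no prompt template is assumed:

import orjson
model_id = 'facebook/opt-1.3b'  # openllm: model id (hypothetical)
model_tag = 'pt-facebook--opt-1.3b:latest'  # openllm: model tag (hypothetical)
adapter_map = orjson.loads("""null""")  # openllm: model adapter map -> None
serialization = 'safetensors'  # openllm: model serialization
prompt_template = None  # openllm: model prompt template
system_message = None  # openllm: model system message
trust_remote_code = False  # openllm: model trust remote code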

View File

@@ -54,28 +54,17 @@ def build_editable(path, package='openllm'):
def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None):
packages = ['openllm', 'scipy'] # apparently bnb misses this one
packages = ['scipy', 'bentoml[tracing]==1.1.9'] # apparently bnb misses this one
if adapter_map is not None:
packages += ['openllm[fine-tune]']
# NOTE: add openllm to the default dependencies
# if users has openllm custom built wheels, it will still respect
# that since bentoml will always install dependencies from requirements.txt
# first, then proceed to install everything inside the wheels/ folder.
if extra_dependencies is not None:
packages += [f'openllm[{k}]' for k in extra_dependencies]
req = llm.config['requirements']
if req is not None:
packages.extend(req)
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false':
packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}")
# XXX: Currently locking this for correctness
packages.extend(['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9'])
wheels = []
if llm.config['requirements'] is not None:
packages.extend(llm.config['requirements'])
wheels = None
built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
if all(i for i in built_wheels):
wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
wheels = [llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels]
return PythonOptions(
packages=packages,
wheels=wheels,
@@ -90,30 +79,25 @@ def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=N
def construct_docker_options(
llm, _, quantize, adapter_map, dockerfile_template, serialisation, container_registry, container_version_strategy
):
from openllm_cli._factory import parse_config_options
from openllm_cli.entrypoint import process_environ
environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
env_dict = {
'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
'OPENLLM_SERIALIZATION': serialisation,
'BENTOML_DEBUG': str(True),
'BENTOML_QUIET': str(False),
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
}
if adapter_map:
env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
if llm._system_message:
env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
if llm._prompt_template:
env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
if quantize:
env_dict['OPENLLM_QUANTIZE'] = str(quantize)
environ = process_environ(
llm.config,
llm.config['timeout'],
1.0,
None,
True,
llm.model_id,
None,
llm._serialisation,
llm,
llm._system_message,
llm._prompt_template,
use_current_env=False,
)
return DockerOptions(
base_image=oci.RefResolver.construct_base_image(container_registry, container_version_strategy),
env=env_dict,
env=environ,
dockerfile_template=dockerfile_template,
)
@@ -121,6 +105,10 @@ def construct_docker_options(
OPENLLM_MODEL_ID = '# openllm: model id'
OPENLLM_MODEL_TAG = '# openllm: model tag'
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'
class _ServiceVarsFormatter(string.Formatter):
@@ -156,6 +144,26 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
identifier = OPENLLM_MODEL_ADAPTER_MAP
class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
keyword = '__model_prompt_template__'
identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
class ModelSystemMessageFormatter(_ServiceVarsFormatter):
keyword = '__model_system_message__'
identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
class ModelSerializationFormatter(_ServiceVarsFormatter):
keyword = '__model_serialization__'
identifier = OPENLLM_MODEL_SERIALIZATION
class ModelTrustRemoteCodeFormatter(_ServiceVarsFormatter):
keyword = '__model_trust_remote_code__'
identifier = OPENLLM_MODEL_TRUST_REMOTE_CODE
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
_service_vars_file = Path(os.path.abspath(__file__)).parent.parent / '_service_vars_pkg.py'
@@ -164,6 +172,8 @@ def write_service(llm, llm_fs, adapter_map):
model_id_formatter = ModelIdFormatter(llm.model_id)
model_tag_formatter = ModelTagFormatter(str(llm.tag))
adapter_map_formatter = ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode())
serialization_formatter = ModelSerializationFormatter(llm.config['serialisation'])
trust_remote_code_formatter = ModelTrustRemoteCodeFormatter(str(llm.trust_remote_code))
logger.debug(
'Generating service vars file for %s at %s (dir=%s)', llm.model_id, '_service_vars.py', llm_fs.getsyspath('/')
@@ -177,6 +187,20 @@ def write_service(llm, llm_fs, adapter_map):
src_contents[i] = model_tag_formatter.parse_line(it)
elif adapter_map_formatter.identifier in it:
src_contents[i] = adapter_map_formatter.parse_line(it)
elif serialization_formatter.identifier in it:
src_contents[i] = serialization_formatter.parse_line(it)
elif trust_remote_code_formatter.identifier in it:
src_contents[i] = trust_remote_code_formatter.parse_line(it)
elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
if llm._prompt_template:
src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
else:
src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
if llm._system_message:
src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
else:
src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)
script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
if SHOW_CODEGEN:
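The new formatters follow the same substitution scheme as the existing ones: each rewrites the line carrying its '# openllm: ...' identifier by filling the '{__keyword__}' placeholder. A simplified stand-in for that step (the real parse_line builds on string.Formatter, so this is an approximation):

def parse_line(line: str, keyword: str, value: str) -> str:
  # Approximation of _ServiceVarsFormatter.parse_line: fill the single named placeholder.
  return line.replace('{%s}' % keyword, value)

line = 'trust_remote_code = {__model_trust_remote_code__}  # openllm: model trust remote code\n'
print(parse_line(line, '__model_trust_remote_code__', 'False'), end='')
# -> trust_remote_code = False  # openllm: model trust remote code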

View File

@@ -622,9 +622,22 @@ def start_grpc_command(
def process_environ(
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, system_message, prompt_template
config,
server_timeout,
wpr,
device,
cors,
model_id,
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
use_current_env=True,
) -> t.Dict[str, t.Any]:
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
environ = parse_config_options(
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
)
environ.update(
{
'OPENLLM_MODEL_ID': model_id,
@@ -1019,22 +1032,21 @@ def build_command(
),
)
backend_warning(llm.__llm_backend__, build=True)
os.environ.update(
{
'TORCH_DTYPE': dtype,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_SERIALIZATION': llm._serialisation,
'OPENLLM_MODEL_ID': llm.model_id,
'OPENLLM_ADAPTER_MAP': orjson.dumps(None).decode(),
}
**process_environ(
llm.config,
llm.config['timeout'],
1.0,
None,
True,
llm.model_id,
None,
llm._serialisation,
llm,
llm._system_message,
llm._prompt_template,
)
)
if llm.quantise:
os.environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
if system_message:
os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
try:
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
@@ -1049,7 +1061,7 @@ def build_command(
llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
adapter_map: dict[str, str] | None = None
adapter_map = None
if adapter_id and not build_ctx:
ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
if adapter_id: