diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py
index 471a84ce..58e72f1f 100644
--- a/openllm-core/src/openllm_core/utils/__init__.py
+++ b/openllm-core/src/openllm_core/utils/__init__.py
@@ -176,18 +176,18 @@ def generate_hash_from_file(f, algorithm='sha1'):
   return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())


-def getenv(env, default=None, var=None, return_type=t.Any):
-  env_key = {env.upper(), f'OPENLLM_{env.upper()}'}
+def getenv(env, default=None, var=None):
+  env_key = set([env.upper(), f'OPENLLM_{env.upper()}'])
   if var is not None:
     env_key = set(var) | env_key

   def callback(k: str) -> t.Any:
     _var = os.getenv(k)
-    if _var and k.startswith('OPENLLM_'):
+    if _var is not None and k.startswith('OPENLLM_'):
       logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
     return _var

-  return t.cast(return_type, first_not_none(*(callback(k) for k in env_key), default=default))
+  return first_not_none(*(callback(k) for k in env_key), default=default)


 def field_env_key(key, suffix=None):
diff --git a/openllm-core/src/openllm_core/utils/import_utils.py b/openllm-core/src/openllm_core/utils/import_utils.py
index 048bb72e..183b81a6 100644
--- a/openllm-core/src/openllm_core/utils/import_utils.py
+++ b/openllm-core/src/openllm_core/utils/import_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib, importlib.metadata, importlib.util, os, inspect, typing as t
 from .codegen import _make_method
 from ._constants import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
diff --git a/openllm-python/src/_openllm_tiny/_entrypoint.py b/openllm-python/src/_openllm_tiny/_entrypoint.py
index 38fac992..f9b71206 100644
--- a/openllm-python/src/_openllm_tiny/_entrypoint.py
+++ b/openllm-python/src/_openllm_tiny/_entrypoint.py
@@ -269,7 +269,6 @@ def start_command(
     'OPENLLM_CONFIG': llm_config.model_dump_json(),
     'DTYPE': dtype,
     'TRUST_REMOTE_CODE': str(trust_remote_code),
-    'MAX_MODEL_LEN': orjson.dumps(max_model_len).decode(),
     'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
     'SERVICES_CONFIG': orjson.dumps(
       dict(
@@ -277,13 +276,16 @@
       )
     ).decode(),
   })
+  if max_model_len is not None:
+    os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len).decode()
   if quantize:
     os.environ['QUANTIZE'] = str(quantize)

   working_dir = os.path.abspath(os.path.dirname(__file__))
   if sys.path[0] != working_dir:
     sys.path.insert(0, working_dir)
-  load('.', working_dir=working_dir).inject_config()
+  service = load('.', working_dir=working_dir)
+  service.inject_config()
   serve_http('.', working_dir=working_dir)
diff --git a/openllm-python/src/_openllm_tiny/_service_vars.py b/openllm-python/src/_openllm_tiny/_service_vars.py
index a4de26b8..0d665ace 100644
--- a/openllm-python/src/_openllm_tiny/_service_vars.py
+++ b/openllm-python/src/_openllm_tiny/_service_vars.py
@@ -1,6 +1,4 @@
-import orjson, typing as t, openllm_core.utils as coreutils
-from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
-from _openllm_tiny._llm import Dtype
+import orjson, openllm_core.utils as coreutils

 (
   model_id,
@@ -13,18 +11,13 @@ from _openllm_tiny._llm import Dtype
   gpu_memory_utilization,
   services_config,
 ) = (
-  coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
-  coreutils.getenv('model_name', return_type=str),
-  coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
-  coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
-  coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
+  coreutils.getenv('model_id', var=['MODEL_ID']),
+  coreutils.getenv('model_name'),
+  coreutils.getenv('quantize', var=['QUANTISE']),
+  coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION']),
+  coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE']),
   coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
-  t.cast(t.Optional[int], orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None)))),
-  t.cast(
-    float,
-    orjson.loads(
-      coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])
-    ),
-  ),
-  t.cast(t.Dict[str, t.Any], orjson.loads(coreutils.getenv('services_config', orjson.dumps({})))),
+  orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None))),
+  orjson.loads(coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])),
+  orjson.loads(coreutils.getenv('services_config', orjson.dumps({}))),
 )
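For context, a minimal standalone sketch of how the revised `getenv` resolves a value, probing both the bare name and the deprecated `OPENLLM_`-prefixed alias. The simplified `first_not_none` and the `print`-based warning here are stand-ins for the library's own helpers, not its exact implementation:

```python
import os

def first_not_none(*args, default=None):
  # Return the first argument that is not None, else the default.
  return next((arg for arg in args if arg is not None), default)

def getenv(env, default=None, var=None):
  # Probe both the bare name and the deprecated OPENLLM_-prefixed alias,
  # plus any extra names passed via `var`.
  env_key = set([env.upper(), f'OPENLLM_{env.upper()}'])
  if var is not None:
    env_key = set(var) | env_key

  def callback(k):
    _var = os.getenv(k)
    if _var is not None and k.startswith('OPENLLM_'):
      print(f"Using '{k}' is deprecated, use '{k[8:]}' instead.")
    return _var

  return first_not_none(*(callback(k) for k in env_key), default=default)

os.environ['OPENLLM_MAX_MODEL_LEN'] = '4096'
print(getenv('max_model_len'))  # -> '4096', plus a deprecation warning
```

Note that dropping `return_type`/`t.cast` changes nothing at runtime: the cast was erased at type-check time, so callers such as `_service_vars.py` always received plain strings either way.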
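And a sketch of the JSON-over-environment handshake that the `_entrypoint.py` and `_service_vars.py` hunks rely on, with parent and child collapsed into one script for illustration. `max_model_len = 2048` is a hypothetical CLI value; `orjson.dumps` returns `bytes`, hence the `.decode()` before assigning into `os.environ`:

```python
import os
import orjson

max_model_len = 2048  # hypothetical CLI value; None when the flag is omitted

# Parent (start_command side): only export the knob when it is explicitly set.
if max_model_len is not None:
  os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len).decode()

# Child (_service_vars side): parse the JSON payload, falling back to null
# when the variable was never exported.
raw = os.environ.get('MAX_MODEL_LEN', orjson.dumps(None).decode())
print(orjson.loads(raw))  # -> 2048, or None if the flag was omitted
```

Skipping the export when the value is `None` lets the child's `default=orjson.dumps(None)` supply the fallback, rather than the parent always writing a literal `null` into the environment.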