Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-02-07 22:33:28 -05:00.
fix(load): make sure to respect MAX_MODEL_LEN from env
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -269,7 +269,6 @@ def start_command(
|
||||
'OPENLLM_CONFIG': llm_config.model_dump_json(),
|
||||
'DTYPE': dtype,
|
||||
'TRUST_REMOTE_CODE': str(trust_remote_code),
|
||||
'MAX_MODEL_LEN': orjson.dumps(max_model_len).decode(),
|
||||
'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
|
||||
'SERVICES_CONFIG': orjson.dumps(
|
||||
dict(
|
||||
@@ -277,13 +276,16 @@ def start_command(
|
||||
)
|
||||
).decode(),
|
||||
})
|
||||
if max_model_len is not None:
|
||||
os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len)
|
||||
if quantize:
|
||||
os.environ['QUANTIZE'] = str(quantize)
|
||||
|
||||
working_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
if sys.path[0] != working_dir:
|
||||
sys.path.insert(0, working_dir)
|
||||
load('.', working_dir=working_dir).inject_config()
|
||||
service = load('.', working_dir=working_dir)
|
||||
service.inject_config()
|
||||
serve_http('.', working_dir=working_dir)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import orjson, typing as t, openllm_core.utils as coreutils
|
||||
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
|
||||
from _openllm_tiny._llm import Dtype
|
||||
import orjson, openllm_core.utils as coreutils
|
||||
|
||||
(
|
||||
model_id,
|
||||
@@ -13,18 +11,13 @@ from _openllm_tiny._llm import Dtype
|
||||
gpu_memory_utilization,
|
||||
services_config,
|
||||
) = (
|
||||
coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
|
||||
coreutils.getenv('model_name', return_type=str),
|
||||
coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
|
||||
coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
|
||||
coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
|
||||
coreutils.getenv('model_id', var=['MODEL_ID']),
|
||||
coreutils.getenv('model_name'),
|
||||
coreutils.getenv('quantize', var=['QUANTISE']),
|
||||
coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION']),
|
||||
coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE']),
|
||||
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
t.cast(t.Optional[int], orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None)))),
|
||||
t.cast(
|
||||
float,
|
||||
orjson.loads(
|
||||
coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])
|
||||
),
|
||||
),
|
||||
t.cast(t.Dict[str, t.Any], orjson.loads(coreutils.getenv('services_config', orjson.dumps({})))),
|
||||
orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None))),
|
||||
orjson.loads(coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])),
|
||||
orjson.loads(coreutils.getenv('services_config', orjson.dumps({}))),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user