fix(load): make sure to respect MAX_MODEL_LEN from env

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2024-03-21 07:44:49 +00:00
parent 295a3b1061
commit 51bec78ee9
4 changed files with 19 additions and 22 deletions

View File

@@ -269,7 +269,6 @@ def start_command(
'OPENLLM_CONFIG': llm_config.model_dump_json(),
'DTYPE': dtype,
'TRUST_REMOTE_CODE': str(trust_remote_code),
'MAX_MODEL_LEN': orjson.dumps(max_model_len).decode(),
'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
'SERVICES_CONFIG': orjson.dumps(
dict(
@@ -277,13 +276,16 @@ def start_command(
)
).decode(),
})
if max_model_len is not None:
os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len)
if quantize:
os.environ['QUANTIZE'] = str(quantize)
working_dir = os.path.abspath(os.path.dirname(__file__))
if sys.path[0] != working_dir:
sys.path.insert(0, working_dir)
load('.', working_dir=working_dir).inject_config()
service = load('.', working_dir=working_dir)
service.inject_config()
serve_http('.', working_dir=working_dir)

View File

@@ -1,6 +1,4 @@
import orjson, typing as t, openllm_core.utils as coreutils
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
from _openllm_tiny._llm import Dtype
import orjson, openllm_core.utils as coreutils
(
model_id,
@@ -13,18 +11,13 @@ from _openllm_tiny._llm import Dtype
gpu_memory_utilization,
services_config,
) = (
coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
coreutils.getenv('model_name', return_type=str),
coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
coreutils.getenv('model_id', var=['MODEL_ID']),
coreutils.getenv('model_name'),
coreutils.getenv('quantize', var=['QUANTISE']),
coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION']),
coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE']),
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
t.cast(t.Optional[int], orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None)))),
t.cast(
float,
orjson.loads(
coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])
),
),
t.cast(t.Dict[str, t.Any], orjson.loads(coreutils.getenv('services_config', orjson.dumps({})))),
orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None))),
orjson.loads(coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])),
orjson.loads(coreutils.getenv('services_config', orjson.dumps({}))),
)