diff --git a/openllm-python/src/_openllm_tiny/_entrypoint.py b/openllm-python/src/_openllm_tiny/_entrypoint.py index a8a64bad..a7ecaaeb 100644 --- a/openllm-python/src/_openllm_tiny/_entrypoint.py +++ b/openllm-python/src/_openllm_tiny/_entrypoint.py @@ -69,6 +69,8 @@ def parse_device_callback( # NOTE: --device all is a special case if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices())) + if len(el) == 1 and el[0] == 'gpu': + return ('0',) return el @@ -266,9 +268,7 @@ def start_command( 'TRUST_REMOTE_CODE': str(trust_remote_code), 'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(), 'SERVICES_CONFIG': orjson.dumps( - dict( - resources={'gpu' if device else 'cpu': len(device) if device else 'cpu_count'}, traffic=dict(timeout=timeout) - ) + dict(resources={'gpu' if device else 'cpu': len(device) if device else '1'}, traffic=dict(timeout=timeout)) ).decode(), }) if max_model_len is not None: