feat(API): add light support for batch inference (#1004)
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
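This change threads a private _mode switch ('async' for online serving, 'batch' for offline inference) through LLM.from_model and adds an LLM.batch() method that drives a synchronous vllm.LLMEngine instead of the AsyncLLMEngine. A minimal usage sketch of the new surface, assuming a GPU machine with vLLM installed; the model id, prompts, and sampling field names are placeholders, not taken from this diff:

import openllm

# Offline batch inference: every prompt is queued on a synchronous LLMEngine and
# the results are collected once all requests have finished.
llm = openllm.LLM.from_model('facebook/opt-125m', mode='batch')  # hypothetical model id
outputs = llm.batch(
  ['Hello, my name is', 'The capital of France is'],
  sampling_params={'max_new_tokens': 32},  # illustrative field name
)
for out in outputs:
  print(out.outputs[0].text)  # GenerationOutput layout assumed from openllm_core._schemas

The default mode stays 'async', so existing call sites that use LLM.generate or LLM.generate_iterator for online serving are unaffected.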
@@ -8,6 +8,7 @@ from openllm_core.utils import (
   normalise_model_name,
   gen_random_uuid,
   dict_filter_none,
+  Counter,
 )
 from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
 from openllm_core._schemas import GenerationOutput
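The newly imported Counter is what batch() uses further down to mint request ids via next(self._counter). The snippet below is not the openllm_core implementation, only an itertools-based stand-in for the behaviour the diff relies on:

import itertools

counter = itertools.count()      # stand-in for openllm_core.utils.Counter
request_id = str(next(counter))  # '0', then '1', '2', ... for each added request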
@@ -15,7 +16,7 @@ from openllm_core._schemas import GenerationOutput
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
 
 if t.TYPE_CHECKING:
-  from vllm import RequestOutput
+  from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
 
 
 def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
@@ -28,6 +29,12 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
   return v
 
 
+def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
+  if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
+    raise ValueError(f'Invalid quantization method: {v}')
+  return v
+
+
 @attr.define(init=False)
 class LLM:
   model_id: str
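check_quantization is attached below as an attrs field validator, so an unsupported method now fails at construction time instead of deep inside engine setup. A small sketch of that behaviour on a stand-in class rather than the real LLM:

import attr
import typing as t

@attr.define
class _Sketch:
  quantise: t.Optional[str] = attr.field(default=None, validator=check_quantization)

_Sketch(quantise='awq')   # accepted: one of {'gptq', 'awq', 'squeezellm'}
_Sketch(quantise='int4')  # raises ValueError: Invalid quantization method: int4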
@@ -35,80 +42,50 @@ class LLM:
   serialisation: LiteralSerialisation
   config: openllm_core.LLMConfig
   dtype: Dtype
-  quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
+  quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
   trust_remote_code: bool = attr.field(default=False)
   engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
 
+  _mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
   _path: str = attr.field(
     init=False,
     default=attr.Factory(
       lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
     ),
   )
+  _counter: Counter = attr.field(init=False, factory=lambda: Counter(), repr=False)
 
-  def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
+  def __init__(
+    self, _: str = '', /, _internal: bool = False, _mode: t.Literal['batch', 'async'] = 'async', **kwargs: t.Any
+  ) -> None:
     if not _internal:
       raise RuntimeError(
         'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
       )
+    kwargs['mode'] = _mode
     self.__attrs_init__(**kwargs)
 
   def __attrs_post_init__(self):
     assert self._model  # Ensure the model is initialised.
-
-  @functools.cached_property
-  def _model(self):
-    if is_vllm_available():
-      num_gpus, dev = 1, openllm.utils.device_count()
-      if dev >= 2:
-        num_gpus = min(dev // 2 * 2, dev)
-      quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
-      dtype = 'float16' if quantise == 'gptq' else self.dtype  # NOTE: quantise GPTQ doesn't support bfloat16 yet.
-
-      self.engine_args.update({
-        'worker_use_ray': False,
-        'engine_use_ray': False,
-        'tokenizer_mode': 'auto',
-        'tensor_parallel_size': num_gpus,
-        'model': self._path,
-        'tokenizer': self._path,
-        'trust_remote_code': self.trust_remote_code,
-        'dtype': dtype,
-        'quantization': quantise,
-      })
-      if 'disable_log_stats' not in self.engine_args:
-        self.engine_args['disable_log_stats'] = not get_debug_mode()
-      if 'disable_log_requests' not in self.engine_args:
-        self.engine_args['disable_log_requests'] = not get_debug_mode()
-      if 'gpu_memory_utilization' not in self.engine_args:
-        self.engine_args['gpu_memory_utilization'] = 0.9
-
-      try:
-        from vllm import AsyncEngineArgs, AsyncLLMEngine
-
-        return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
-      except Exception as err:
-        traceback.print_exc()
-        raise openllm.exceptions.OpenLLMException(
-          f'Failed to initialise vLLMEngine due to the following error:\n{err}'
-        ) from err
-    else:
-      raise RuntimeError('Currently, OpenLLM is only supported running with a GPU and vLLM backend.')
+    if self.config is None:
+      self.config = openllm_core.AutoConfig.from_id(self.model_id, trust_remote_code=self.trust_remote_code)
 
   @classmethod
   def from_model(
     cls,
     model_id: str,
-    dtype: str,
+    dtype: str = 'auto',
     bentomodel: bentoml.Model | None = None,
     serialisation: LiteralSerialisation = 'safetensors',
     quantise: LiteralQuantise | None = None,
     trust_remote_code: bool = False,
     llm_config: openllm_core.LLMConfig | None = None,
+    mode: t.Literal['batch', 'async'] = 'async',
     **engine_args: t.Any,
   ) -> LLM:
     return cls(
       _internal=True,
+      _mode=mode,
       model_id=model_id,
       bentomodel=bentomodel,
       quantise=quantise,
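The kwargs['mode'] = _mode line leans on an attrs convention: a private attribute with a leading underscore is exposed by the generated initializer without the underscore. A standalone demonstration of just that behaviour, unrelated to the rest of the class:

import attr

@attr.define
class Demo:
  _mode: str = 'async'

d = Demo(mode='batch')  # the generated __init__ parameter is 'mode', the stored slot is '_mode'
print(d._mode)          # 'batch'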
@@ -131,6 +108,56 @@ class LLM:
       'model_id': self.model_id,
     }
 
+  def _get_engine_args(self, mode: t.Literal['batch', 'async'] = 'async') -> AsyncEngineArgs | EngineArgs:
+    if not is_vllm_available():
+      raise RuntimeError("'vllm' is not available. Make sure to install with 'pip install vllm>0.4'")
+
+    from vllm import EngineArgs, AsyncEngineArgs
+
+    num_gpus, dev = 1, openllm.utils.device_count()
+    if dev >= 2:
+      num_gpus = min(dev // 2 * 2, dev)
+    dtype = 'float16' if self.quantise == 'gptq' else self.dtype  # NOTE: quantise GPTQ doesn't support bfloat16 yet.
+
+    self.engine_args.update({
+      'worker_use_ray': False,
+      'tokenizer_mode': 'auto',
+      'tensor_parallel_size': num_gpus,
+      'model': self._path,
+      'tokenizer': self._path,
+      'trust_remote_code': self.trust_remote_code,
+      'dtype': dtype,
+      'quantization': self.quantise,
+    })
+    if 'disable_log_stats' not in self.engine_args:
+      self.engine_args['disable_log_stats'] = not get_debug_mode()
+    if 'gpu_memory_utilization' not in self.engine_args:
+      self.engine_args['gpu_memory_utilization'] = 0.9
+
+    if mode == 'async':
+      if 'disable_log_requests' not in self.engine_args:
+        self.engine_args['disable_log_requests'] = not get_debug_mode()
+      if 'engine_use_ray' not in self.engine_args:
+        self.engine_args['engine_use_ray'] = False
+
+    return AsyncEngineArgs(**self.engine_args) if mode == 'async' else EngineArgs(**self.engine_args)
+
+  @functools.cached_property
+  def _model(self):
+    if is_vllm_available():
+      try:
+        from vllm import AsyncLLMEngine, LLMEngine
+
+        engine_cls = AsyncLLMEngine if self._mode == 'async' else LLMEngine
+        return engine_cls.from_engine_args(self._get_engine_args(self._mode))
+      except Exception as err:
+        traceback.print_exc()
+        raise openllm.exceptions.OpenLLMException(
+          f'Failed to initialise vLLMEngine due to the following error:\n{err}'
+        ) from err
+    else:
+      raise RuntimeError('Currently, OpenLLM is only supported running with a GPU and vLLM backend.')
+
   @functools.cached_property
   def _tokenizer(self):
     import transformers
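_get_engine_args now centralises the vLLM configuration and only layers the async-only knobs (disable_log_requests, engine_use_ray) on top when serving online, so batch mode can hand the same dict to a plain EngineArgs. The tensor_parallel_size heuristic simply rounds the visible GPU count down to an even number once more than one device is present; a quick check of that arithmetic:

for dev in (1, 2, 3, 4, 5, 8):
  num_gpus = 1
  if dev >= 2:
    num_gpus = min(dev // 2 * 2, dev)
  print(dev, '->', num_gpus)  # 1->1, 2->2, 3->2, 4->4, 5->4, 8->8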
@@ -144,9 +171,13 @@ class LLM:
     stop: str | t.Iterable[str] | None = None,
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
     adapter_name: str | None = None,
     **attrs: t.Any,
   ) -> t.AsyncGenerator[RequestOutput, None]:
+    if self._mode != 'async':
+      raise RuntimeError(
+        "'generate_iterator' is reserved only for online serving. For batch inference use 'LLM.batch' instead."
+      )
+
     from vllm import SamplingParams
-
     config = self.config.model_construct_env(**dict_filter_none(attrs))
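With this guard, streaming stays an async-only path. A rough sketch of how it is consumed; the model id and generation fields are placeholders, and the per-chunk attributes come from vLLM's RequestOutput rather than from this diff:

import asyncio
import openllm

async def main():
  llm = openllm.LLM.from_model('facebook/opt-125m', mode='async')  # hypothetical model id
  async for chunk in llm.generate_iterator('Hello, my name is', max_new_tokens=16):
    print(chunk.outputs[0].text)  # cumulative text so far, per vLLM's RequestOutput

asyncio.run(main())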
@@ -192,22 +223,68 @@ class LLM:
     stop: str | t.Iterable[str] | None = None,
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
     adapter_name: str | None = None,
     **attrs: t.Any,
   ) -> GenerationOutput:
+    if self._mode != 'async':
+      raise RuntimeError(
+        "'generate' is reserved only for online serving. For batch inference use 'LLM.batch' instead."
+      )
+
     if stop is not None:
       attrs.update({'stop': stop})
     if stop_token_ids is not None:
       attrs.update({'stop_token_ids': stop_token_ids})
     config = self.config.model_construct_env(**attrs)
     async for result in self.generate_iterator(
-      prompt,
-      prompt_token_ids=prompt_token_ids,
-      request_id=request_id,
-      adapter_name=adapter_name,
-      **config.model_dump(),
+      prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **config.model_dump()
     ):
       pass
     if (final_result := result) is None:
       raise RuntimeError('No result is returned.')
     return GenerationOutput.from_vllm(final_result).model_copy(update=dict(prompt=prompt))
 
+  def batch(
+    self,
+    prompts: t.Union[str, t.List[str]],
+    sampling_params: t.Optional[t.Union[t.Dict[str, t.Any], t.List[t.Dict[str, t.Any]]]] = None,
+  ) -> t.List[GenerationOutput]:
+    if self._mode != 'batch':
+      raise RuntimeError(
+        "'batch' is reserved for offline batch inference. For online serving use 'LLM.generate' or 'LLM.generate_iterator' instead."
+      )
+
+    from vllm import SamplingParams
+
+    if isinstance(prompts, str):
+      prompts = [prompts]
+    num_requests = len(prompts)
+
+    if isinstance(sampling_params, list) and len(sampling_params) != num_requests:
+      raise ValueError('Mismatch number of prompts and sampling params. (They must be the same length)')
+
+    if sampling_params is None:
+      configs = [self.config] * num_requests
+    elif isinstance(sampling_params, dict):
+      configs = [self.config.model_construct_env(**sampling_params)] * num_requests
+    else:
+      configs = [self.config.model_construct_env(**it) for it in sampling_params]
+
+    for i in range(num_requests):
+      request_id = str(next(self._counter))
+      prompt = prompts[i]
+      config = configs[i]
+      self._model.add_request(
+        request_id,
+        prompt,
+        SamplingParams(**{k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())}),
+      )
+
+    # now run the engine
+    outputs = []
+    while self._model.has_unfinished_requests():
+      step_outputs = self._model.step()
+      for output in step_outputs:
+        if output.finished:
+          outputs.append(GenerationOutput.from_vllm(output))  # noqa
+    outputs = sorted(outputs, key=lambda x: int(x.request_id))
+    return outputs
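sampling_params may also be a list with one dict per prompt, which batch() pairs positionally with the prompts before building per-request SamplingParams. A hedged sketch; the model id and field names are illustrative and must map onto fields the LLMConfig understands:

import openllm

llm = openllm.LLM.from_model('facebook/opt-125m', mode='batch')  # hypothetical model id
prompts = ['Summarise: ...', 'Translate to French: ...']
params = [
  {'temperature': 0.2, 'max_new_tokens': 64},  # applied to prompts[0]
  {'temperature': 0.8, 'max_new_tokens': 32},  # applied to prompts[1]
]
outputs = llm.batch(prompts, sampling_params=params)  # one GenerationOutput per prompt, ordered by request id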