feat(API): add light support for batch inference (#1004)

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2024-05-31 20:36:12 -04:00
Committed by: GitHub
Parent: 162458ffe6
Commit: 45aceb172f
3 changed files with 145 additions and 50 deletions

@@ -8,6 +8,7 @@ from openllm_core.utils import (
normalise_model_name,
gen_random_uuid,
dict_filter_none,
Counter,
)
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
from openllm_core._schemas import GenerationOutput
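The newly imported Counter hands out monotonically increasing request ids for the offline batch path (see the _counter field and LLM.batch below). A minimal sketch of the assumed behaviour; the real implementation lives in openllm_core.utils:

from itertools import count

class Counter:
  # Assumed behaviour: yields 0, 1, 2, ... on each next() call, mirroring itertools.count.
  def __init__(self, start: int = 0) -> None:
    self._it = count(start)
  def __next__(self) -> int:
    return next(self._it)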
@@ -15,7 +16,7 @@ from openllm_core._schemas import GenerationOutput
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
if t.TYPE_CHECKING:
from vllm import RequestOutput
from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
@@ -28,6 +29,12 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.
return v
def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
raise ValueError(f'Invalid quantization method: {v}')
return v
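With this validator attached to the quantise field, an unsupported quantisation method is now rejected at construction time (previously it was silently dropped). A quick illustration; the model id is only illustrative:

llm = openllm.LLM.from_model('facebook/opt-125m', quantise='int8')
# ValueError: Invalid quantization method: int8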
@attr.define(init=False)
class LLM:
model_id: str
@@ -35,80 +42,50 @@ class LLM:
serialisation: LiteralSerialisation
config: openllm_core.LLMConfig
dtype: Dtype
quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
trust_remote_code: bool = attr.field(default=False)
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
_mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
_path: str = attr.field(
init=False,
default=attr.Factory(
lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
),
)
_counter: Counter = attr.field(init=False, factory=lambda: Counter(), repr=False)
def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
def __init__(
self, _: str = '', /, _internal: bool = False, _mode: t.Literal['batch', 'async'] = 'async', **kwargs: t.Any
) -> None:
if not _internal:
raise RuntimeError(
'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
)
kwargs['mode'] = _mode
self.__attrs_init__(**kwargs)
def __attrs_post_init__(self):
assert self._model # Ensure the model is initialised.
@functools.cached_property
def _model(self):
if is_vllm_available():
num_gpus, dev = 1, openllm.utils.device_count()
if dev >= 2:
num_gpus = min(dev // 2 * 2, dev)
quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
dtype = 'float16' if quantise == 'gptq' else self.dtype # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
self.engine_args.update({
'worker_use_ray': False,
'engine_use_ray': False,
'tokenizer_mode': 'auto',
'tensor_parallel_size': num_gpus,
'model': self._path,
'tokenizer': self._path,
'trust_remote_code': self.trust_remote_code,
'dtype': dtype,
'quantization': quantise,
})
if 'disable_log_stats' not in self.engine_args:
self.engine_args['disable_log_stats'] = not get_debug_mode()
if 'disable_log_requests' not in self.engine_args:
self.engine_args['disable_log_requests'] = not get_debug_mode()
if 'gpu_memory_utilization' not in self.engine_args:
self.engine_args['gpu_memory_utilization'] = 0.9
try:
from vllm import AsyncEngineArgs, AsyncLLMEngine
return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
except Exception as err:
traceback.print_exc()
raise openllm.exceptions.OpenLLMException(
f'Failed to initialise vLLMEngine due to the following error:\n{err}'
) from err
else:
raise RuntimeError('Currently, OpenLLM only supports running on a GPU with the vLLM backend.')
if self.config is None:
self.config = openllm_core.AutoConfig.from_id(self.model_id, trust_remote_code=self.trust_remote_code)
@classmethod
def from_model(
cls,
model_id: str,
dtype: str,
dtype: str = 'auto',
bentomodel: bentoml.Model | None = None,
serialisation: LiteralSerialisation = 'safetensors',
quantise: LiteralQuantise | None = None,
trust_remote_code: bool = False,
llm_config: openllm_core.LLMConfig | None = None,
mode: t.Literal['batch', 'async'] = 'async',
**engine_args: t.Any,
) -> LLM:
return cls(
_internal=True,
_mode=mode,
model_id=model_id,
bentomodel=bentomodel,
quantise=quantise,
@@ -131,6 +108,56 @@ class LLM:
'model_id': self.model_id,
}
def _get_engine_args(self, mode: t.Literal['batch', 'async'] = 'async') -> AsyncEngineArgs | EngineArgs:
if not is_vllm_available():
raise RuntimeError("'vllm' is not available. Make sure to install with 'pip install vllm>0.4'")
from vllm import EngineArgs, AsyncEngineArgs
num_gpus, dev = 1, openllm.utils.device_count()
if dev >= 2:
num_gpus = min(dev // 2 * 2, dev)
dtype = 'float16' if self.quantise == 'gptq' else self.dtype # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
self.engine_args.update({
'worker_use_ray': False,
'tokenizer_mode': 'auto',
'tensor_parallel_size': num_gpus,
'model': self._path,
'tokenizer': self._path,
'trust_remote_code': self.trust_remote_code,
'dtype': dtype,
'quantization': self.quantise,
})
if 'disable_log_stats' not in self.engine_args:
self.engine_args['disable_log_stats'] = not get_debug_mode()
if 'gpu_memory_utilization' not in self.engine_args:
self.engine_args['gpu_memory_utilization'] = 0.9
if mode == 'async':
if 'disable_log_requests' not in self.engine_args:
self.engine_args['disable_log_requests'] = not get_debug_mode()
if 'engine_use_ray' not in self.engine_args:
self.engine_args['engine_use_ray'] = False
return AsyncEngineArgs(**self.engine_args) if mode == 'async' else EngineArgs(**self.engine_args)
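For instance, on a single-GPU machine with no user overrides, the resolved arguments would look roughly like the sketch below (the model path is illustrative and exact defaults depend on the installed vLLM version):

# mode='batch' -> plain EngineArgs for the synchronous LLMEngine
EngineArgs(
  model='facebook/opt-125m', tokenizer='facebook/opt-125m', tokenizer_mode='auto',
  tensor_parallel_size=1, worker_use_ray=False, trust_remote_code=False,
  dtype='auto', quantization=None, disable_log_stats=True, gpu_memory_utilization=0.9,
)
# mode='async' additionally sets disable_log_requests=True and engine_use_ray=False,
# and returns AsyncEngineArgs for use with AsyncLLMEngine.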
@functools.cached_property
def _model(self):
if is_vllm_available():
try:
from vllm import AsyncLLMEngine, LLMEngine
engine_cls = AsyncLLMEngine if self._mode == 'async' else LLMEngine
return engine_cls.from_engine_args(self._get_engine_args(self._mode))
except Exception as err:
traceback.print_exc()
raise openllm.exceptions.OpenLLMException(
f'Failed to initialise vLLMEngine due to the following error:\n{err}'
) from err
else:
raise RuntimeError('Currently, OpenLLM only supports running on a GPU with the vLLM backend.')
@functools.cached_property
def _tokenizer(self):
import transformers
@@ -144,9 +171,13 @@ class LLM:
stop: str | t.Iterable[str] | None = None,
stop_token_ids: list[int] | None = None,
request_id: str | None = None,
adapter_name: str | None = None,
**attrs: t.Any,
) -> t.AsyncGenerator[RequestOutput, None]:
if self._mode != 'async':
raise RuntimeError(
"'generate_iterator' is reserved only for online serving. For batch inference use 'LLM.batch' instead."
)
from vllm import SamplingParams
config = self.config.model_construct_env(**dict_filter_none(attrs))
@@ -192,22 +223,68 @@ class LLM:
stop: str | t.Iterable[str] | None = None,
stop_token_ids: list[int] | None = None,
request_id: str | None = None,
adapter_name: str | None = None,
**attrs: t.Any,
) -> GenerationOutput:
if self._mode != 'async':
raise RuntimeError(
"'generate' is reserved only for online serving. For batch inference use 'LLM.batch' instead."
)
if stop is not None:
attrs.update({'stop': stop})
if stop_token_ids is not None:
attrs.update({'stop_token_ids': stop_token_ids})
config = self.config.model_construct_env(**attrs)
async for result in self.generate_iterator(
prompt,
prompt_token_ids=prompt_token_ids,
request_id=request_id,
adapter_name=adapter_name,
**config.model_dump(),
prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **config.model_dump()
):
pass
if (final_result := result) is None:
raise RuntimeError('No result is returned.')
return GenerationOutput.from_vllm(final_result).model_copy(update=dict(prompt=prompt))
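For reference, the unchanged online path keeps using the async engine. A usage sketch; the model id, the max_new_tokens value, and the outputs[0].text field access are assumptions (the latter based on GenerationOutput mirroring vLLM's RequestOutput):

import asyncio
import openllm

async def main():
  llm = openllm.LLM.from_model('facebook/opt-125m')  # mode defaults to 'async'
  out = await llm.generate('What is batch inference?', max_new_tokens=32)
  print(out.outputs[0].text)

asyncio.run(main())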
def batch(
self,
prompts: t.Union[str, t.List[str]],
sampling_params: t.Optional[t.Union[t.Dict[str, t.Any], t.List[t.Dict[str, t.Any]]]] = None,
) -> t.List[GenerationOutput]:
if self._mode != 'batch':
raise RuntimeError(
"'batch' is reserved for offline batch inference. For online serving use 'LLM.generate' or 'LLM.generate_iterator' instead."
)
from vllm import SamplingParams
if isinstance(prompts, str):
prompts = [prompts]
num_requests = len(prompts)
if isinstance(sampling_params, list) and len(sampling_params) != num_requests:
raise ValueError('Mismatched number of prompts and sampling params (they must be the same length).')
if sampling_params is None:
configs = [self.config] * num_requests
elif isinstance(sampling_params, dict):
configs = [self.config.model_construct_env(**sampling_params)] * num_requests
else:
configs = [self.config.model_construct_env(**it) for it in sampling_params]
for i in range(num_requests):
request_id = str(next(self._counter))
prompt = prompts[i]
config = configs[i]
self._model.add_request(
request_id,
prompt,
SamplingParams(**{k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())}),
)
# Run the engine step by step until all submitted requests have finished.
outputs = []
while self._model.has_unfinished_requests():
step_outputs = self._model.step()
for output in step_outputs:
if output.finished:
outputs.append(GenerationOutput.from_vllm(output)) # noqa
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
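Putting it together, offline batch inference would look roughly like this sketch (the model id, sampling keys, and outputs[0].text field access are illustrative assumptions):

import openllm

# Construct the LLM in 'batch' mode so the synchronous LLMEngine is used.
llm = openllm.LLM.from_model('facebook/opt-125m', mode='batch')

outputs = llm.batch(
  ['What is the meaning of life?', 'Summarise the benefits of offline batch inference.'],
  sampling_params={'max_new_tokens': 64, 'temperature': 0.2},
)
for out in outputs:
  print(out.request_id, out.outputs[0].text)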