diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py
index e163590f..16985dce 100644
--- a/openllm-core/src/openllm_core/utils/__init__.py
+++ b/openllm-core/src/openllm_core/utils/__init__.py
@@ -371,6 +371,19 @@ _LOGGING_CONFIG = {
 }
 
 
+class Counter:
+  def __init__(self, start: int = 0):
+    self.counter = start
+
+  def __next__(self) -> int:
+    i = self.counter
+    self.counter += 1
+    return i
+
+  def reset(self) -> None:
+    self.counter = 0
+
+
 def configure_logging():
   if get_quiet_mode():
     _LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
diff --git a/openllm-core/src/openllm_core/utils/__init__.pyi b/openllm-core/src/openllm_core/utils/__init__.pyi
index ffc06bd3..6cb84925 100644
--- a/openllm-core/src/openllm_core/utils/__init__.pyi
+++ b/openllm-core/src/openllm_core/utils/__init__.pyi
@@ -114,3 +114,9 @@ def generate_context(framework_name: str) -> ModelContext: ...
 def in_notebook() -> bool: ...
 def flatten_attrs(**attrs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]: ...
 def configure_logging() -> None: ...
+
+class Counter:
+  counter: int
+  def __init__(self, start: int = 0) -> None: ...
+  def __next__(self) -> int: ...
+  def reset(self) -> None: ...
diff --git a/openllm-python/src/_openllm_tiny/_llm.py b/openllm-python/src/_openllm_tiny/_llm.py
index 895d6e30..f314fde6 100644
--- a/openllm-python/src/_openllm_tiny/_llm.py
+++ b/openllm-python/src/_openllm_tiny/_llm.py
@@ -8,6 +8,7 @@ from openllm_core.utils import (
   normalise_model_name,
   gen_random_uuid,
   dict_filter_none,
+  Counter,
 )
 from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
 from openllm_core._schemas import GenerationOutput
@@ -15,7 +16,7 @@ from openllm_core._schemas import GenerationOutput
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
 
 if t.TYPE_CHECKING:
-  from vllm import RequestOutput
+  from vllm import AsyncEngineArgs, EngineArgs, RequestOutput
 
 
 def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
@@ -28,6 +29,12 @@ def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t
   return v
 
 
+def check_quantization(_, attr: attr.Attribute[LiteralQuantise], v: str | None) -> LiteralQuantise | None:
+  if v is not None and v not in {'gptq', 'awq', 'squeezellm'}:
+    raise ValueError(f'Invalid quantization method: {v}')
+  return v
+
+
 @attr.define(init=False)
 class LLM:
   model_id: str
@@ -35,80 +42,50 @@ class LLM:
   serialisation: LiteralSerialisation
   config: openllm_core.LLMConfig
   dtype: Dtype
-  quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
+  quantise: t.Optional[LiteralQuantise] = attr.field(default=None, validator=check_quantization)
   trust_remote_code: bool = attr.field(default=False)
   engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
+  _mode: t.Literal['batch', 'async'] = attr.field(default='async', repr=False)
   _path: str = attr.field(
     init=False,
     default=attr.Factory(
       lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
     ),
   )
+  _counter: Counter = attr.field(init=False, factory=Counter, repr=False)
 
-  def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
+  def __init__(
+    self, _: str = '', /, _internal: bool = False, _mode: t.Literal['batch', 'async'] = 'async', **kwargs: t.Any
+  ) -> None:
     if not _internal:
       raise RuntimeError(
         'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
       )
+    kwargs['mode'] = _mode
     self.__attrs_init__(**kwargs)
 
   def __attrs_post_init__(self):
     assert self._model  # Ensure the model is initialised.
-
-  @functools.cached_property
-  def _model(self):
-    if is_vllm_available():
-      num_gpus, dev = 1, openllm.utils.device_count()
-      if dev >= 2:
-        num_gpus = min(dev // 2 * 2, dev)
-      quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
-      dtype = 'float16' if quantise == 'gptq' else self.dtype  # NOTE: quantise GPTQ doesn't support bfloat16 yet.
-
-      self.engine_args.update({
-        'worker_use_ray': False,
-        'engine_use_ray': False,
-        'tokenizer_mode': 'auto',
-        'tensor_parallel_size': num_gpus,
-        'model': self._path,
-        'tokenizer': self._path,
-        'trust_remote_code': self.trust_remote_code,
-        'dtype': dtype,
-        'quantization': quantise,
-      })
-      if 'disable_log_stats' not in self.engine_args:
-        self.engine_args['disable_log_stats'] = not get_debug_mode()
-      if 'disable_log_requests' not in self.engine_args:
-        self.engine_args['disable_log_requests'] = not get_debug_mode()
-      if 'gpu_memory_utilization' not in self.engine_args:
-        self.engine_args['gpu_memory_utilization'] = 0.9
-
-      try:
-        from vllm import AsyncEngineArgs, AsyncLLMEngine
-
-        return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
-      except Exception as err:
-        traceback.print_exc()
-        raise openllm.exceptions.OpenLLMException(
-          f'Failed to initialise vLLMEngine due to the following error:\n{err}'
-        ) from err
-    else:
-      raise RuntimeError('Currently, OpenLLM is only supported running with a GPU and vLLM backend.')
+    if self.config is None:
+      self.config = openllm_core.AutoConfig.from_id(self.model_id, trust_remote_code=self.trust_remote_code)
 
   @classmethod
   def from_model(
     cls,
     model_id: str,
-    dtype: str,
+    dtype: str = 'auto',
     bentomodel: bentoml.Model | None = None,
     serialisation: LiteralSerialisation = 'safetensors',
     quantise: LiteralQuantise | None = None,
     trust_remote_code: bool = False,
     llm_config: openllm_core.LLMConfig | None = None,
+    mode: t.Literal['batch', 'async'] = 'async',
     **engine_args: t.Any,
   ) -> LLM:
     return cls(
       _internal=True,
+      _mode=mode,
       model_id=model_id,
       bentomodel=bentomodel,
       quantise=quantise,
@@ -131,6 +108,56 @@ class LLM:
       'model_id': self.model_id,
     }
 
+  def _get_engine_args(self, mode: t.Literal['batch', 'async'] = 'async') -> AsyncEngineArgs | EngineArgs:
+    if not is_vllm_available():
+      raise RuntimeError("'vllm' is not available. Make sure to install it with 'pip install \"vllm>0.4\"'")
+
+    from vllm import AsyncEngineArgs, EngineArgs
+
+    num_gpus, dev = 1, openllm.utils.device_count()
+    if dev >= 2:
+      num_gpus = min(dev // 2 * 2, dev)
+    dtype = 'float16' if self.quantise == 'gptq' else self.dtype  # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
+
+    self.engine_args.update({
+      'worker_use_ray': False,
+      'tokenizer_mode': 'auto',
+      'tensor_parallel_size': num_gpus,
+      'model': self._path,
+      'tokenizer': self._path,
+      'trust_remote_code': self.trust_remote_code,
+      'dtype': dtype,
+      'quantization': self.quantise,
+    })
+    if 'disable_log_stats' not in self.engine_args:
+      self.engine_args['disable_log_stats'] = not get_debug_mode()
+    if 'gpu_memory_utilization' not in self.engine_args:
+      self.engine_args['gpu_memory_utilization'] = 0.9
+
+    if mode == 'async':
+      if 'disable_log_requests' not in self.engine_args:
+        self.engine_args['disable_log_requests'] = not get_debug_mode()
+      if 'engine_use_ray' not in self.engine_args:
+        self.engine_args['engine_use_ray'] = False
+
+    return AsyncEngineArgs(**self.engine_args) if mode == 'async' else EngineArgs(**self.engine_args)
+
+  @functools.cached_property
+  def _model(self):
+    if is_vllm_available():
+      try:
+        from vllm import AsyncLLMEngine, LLMEngine
+
+        engine_cls = AsyncLLMEngine if self._mode == 'async' else LLMEngine
+        return engine_cls.from_engine_args(self._get_engine_args(self._mode))
+      except Exception as err:
+        traceback.print_exc()
+        raise openllm.exceptions.OpenLLMException(
+          f'Failed to initialise the vLLM engine due to the following error:\n{err}'
+        ) from err
+    else:
+      raise RuntimeError('Currently, OpenLLM only supports running on a GPU with the vLLM backend.')
+
   @functools.cached_property
   def _tokenizer(self):
     import transformers
@@ -144,9 +171,13 @@ class LLM:
     stop: str | t.Iterable[str] | None = None,
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
-    adapter_name: str | None = None,
     **attrs: t.Any,
   ) -> t.AsyncGenerator[RequestOutput, None]:
+    if self._mode != 'async':
+      raise RuntimeError(
+        "'generate_iterator' is reserved for online serving. For batch inference, use 'LLM.batch' instead."
+      )
+
     from vllm import SamplingParams
 
     config = self.config.model_construct_env(**dict_filter_none(attrs))
@@ -192,22 +223,68 @@ class LLM:
     stop: str | t.Iterable[str] | None = None,
     stop_token_ids: list[int] | None = None,
     request_id: str | None = None,
-    adapter_name: str | None = None,
     **attrs: t.Any,
   ) -> GenerationOutput:
+    if self._mode != 'async':
+      raise RuntimeError(
+        "'generate' is reserved for online serving. For batch inference, use 'LLM.batch' instead."
+      )
+
     if stop is not None:
       attrs.update({'stop': stop})
     if stop_token_ids is not None:
       attrs.update({'stop_token_ids': stop_token_ids})
     config = self.config.model_construct_env(**attrs)
     async for result in self.generate_iterator(
-      prompt,
-      prompt_token_ids=prompt_token_ids,
-      request_id=request_id,
-      adapter_name=adapter_name,
-      **config.model_dump(),
+      prompt, prompt_token_ids=prompt_token_ids, request_id=request_id, **config.model_dump()
     ):
       pass
     if (final_result := result) is None:
       raise RuntimeError('No result is returned.')
     return GenerationOutput.from_vllm(final_result).model_copy(update=dict(prompt=prompt))
+
+  def batch(
+    self,
+    prompts: t.Union[str, t.List[str]],
+    sampling_params: t.Optional[t.Union[t.Dict[str, t.Any], t.List[t.Dict[str, t.Any]]]] = None,
+  ) -> t.List[GenerationOutput]:
+    if self._mode != 'batch':
+      raise RuntimeError(
+        "'batch' is reserved for offline batch inference. For online serving, use 'LLM.generate' or 'LLM.generate_iterator' instead."
+      )
+
+    from vllm import SamplingParams
+
+    if isinstance(prompts, str):
+      prompts = [prompts]
+    num_requests = len(prompts)
+
+    if isinstance(sampling_params, list) and len(sampling_params) != num_requests:
+      raise ValueError('Mismatched number of prompts and sampling params; they must be the same length.')
+
+    if sampling_params is None:
+      configs = [self.config] * num_requests
+    elif isinstance(sampling_params, dict):
+      configs = [self.config.model_construct_env(**sampling_params)] * num_requests
+    else:
+      configs = [self.config.model_construct_env(**it) for it in sampling_params]
+
+    for i in range(num_requests):
+      request_id = str(next(self._counter))
+      prompt = prompts[i]
+      config = configs[i]
+      self._model.add_request(
+        request_id,
+        prompt,
+        SamplingParams(**{k: config[k] for k in inspect.signature(SamplingParams).parameters}),
+      )
+
+    # Run the engine until all requests finish; the monotonic Counter ids keep the sorted outputs in prompt order.
+    outputs = []
+    while self._model.has_unfinished_requests():
+      step_outputs = self._model.step()
+      for output in step_outputs:
+        if output.finished:
+          outputs.append(GenerationOutput.from_vllm(output))  # noqa
+    outputs = sorted(outputs, key=lambda x: int(x.request_id))
+    return outputs
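
Usage sketch (not part of the patch): a minimal example of the new offline batch mode. The model id, sampling fields, and output attributes below are illustrative assumptions, not values confirmed by this diff.

import openllm

# 'mode' defaults to 'async' (online serving via AsyncLLMEngine);
# 'batch' selects the synchronous LLMEngine used by LLM.batch.
llm = openllm.LLM.from_model('facebook/opt-125m', mode='batch')  # hypothetical model id

# A single dict applies the same sampling params to every prompt;
# a list of dicts (one per prompt) customises each request individually.
results = llm.batch(
  ['Hello, world!', 'What is deep learning?'],
  sampling_params={'max_new_tokens': 128, 'temperature': 0.7},  # assumed config field names
)
for res in results:
  print(res.request_id, res.outputs[0].text)

Because request ids come from the monotonically increasing Counter, sorting by request_id returns outputs in the same order as the submitted prompts.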