OpenLLM/openllm-python/src/_openllm_tiny/_llm.py

from __future__ import annotations
import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
from openllm_core.utils import (
  get_debug_mode,
  is_vllm_available,
  normalise_model_name,
  gen_random_uuid,
  dict_filter_none,
)
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
from openllm_core._schemas import GenerationOutput, GenerationInput
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]

if t.TYPE_CHECKING:
  from vllm import RequestOutput


def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
  # Reject any engine arg that does not map to a field on vLLM's AsyncEngineArgs dataclass.
  from vllm import AsyncEngineArgs

  fields = dataclasses.fields(AsyncEngineArgs)
  invalid_args = {k: v for k, v in v.items() if k not in {f.name for f in fields}}
  if len(invalid_args) > 0:
    raise ValueError(f'Invalid engine args: {list(invalid_args)}')
  return v
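
# A minimal sketch of what the validator accepts and rejects. attrs invokes it automatically;
# the direct calls below are for illustration only, and the field names assume the installed
# vLLM version exposes `max_model_len` and `gpu_memory_utilization` on AsyncEngineArgs:
#   check_engine_args(None, None, {'max_model_len': 4096, 'gpu_memory_utilization': 0.85})  # returned unchanged
#   check_engine_args(None, None, {'max_new_tokens': 256})  # raises ValueError: a generation option, not an engine arg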

@attr.define(init=False)
class LLM:
  model_id: str
  bentomodel: bentoml.Model
  serialisation: LiteralSerialisation
  config: openllm_core.LLMConfig
  dtype: Dtype
  quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
  trust_remote_code: bool = attr.field(default=False)
  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
  _path: str = attr.field(
    init=False,
    default=attr.Factory(
      lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
    ),
  )

  def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
    if not _internal:
      raise RuntimeError(
        'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
      )
    self.__attrs_init__(**kwargs)

  def __attrs_post_init__(self):
    assert self._model  # Ensure the model is initialised.

  @functools.cached_property
  def _model(self):
    if is_vllm_available():
      num_gpus, dev = 1, openllm.utils.device_count()
      if dev >= 2:
        num_gpus = min(dev // 2 * 2, dev)  # Use the largest even number of available GPUs for tensor parallelism.
      quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
      dtype = 'float16' if quantise == 'gptq' else self.dtype  # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
      self.engine_args.update({
        'worker_use_ray': False,
        'engine_use_ray': False,
        'tokenizer_mode': 'auto',
        'tensor_parallel_size': num_gpus,
        'model': self._path,
        'tokenizer': self._path,
        'trust_remote_code': self.trust_remote_code,
        'dtype': dtype,
        'quantization': quantise,
      })
      # Quieter logging by default unless debug mode is enabled; user-provided values take precedence.
      if 'disable_log_stats' not in self.engine_args:
        self.engine_args['disable_log_stats'] = not get_debug_mode()
      if 'disable_log_requests' not in self.engine_args:
        self.engine_args['disable_log_requests'] = not get_debug_mode()
      if 'gpu_memory_utilization' not in self.engine_args:
        self.engine_args['gpu_memory_utilization'] = 0.9
      try:
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
      except Exception as err:
        traceback.print_exc()
        raise openllm.exceptions.OpenLLMException(
          f'Failed to initialise the vLLM engine due to the following error:\n{err}'
        ) from err
    else:
      raise RuntimeError('Currently, OpenLLM only supports running on a GPU with the vLLM backend.')
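
  # Illustrative sketch (not executed) of the engine args that reach vLLM for a single-GPU,
  # non-quantised model resolved from the BentoML model store; the path is a placeholder:
  #   {'model': '/home/user/bentoml/models/...', 'tokenizer': '/home/user/bentoml/models/...',
  #    'tokenizer_mode': 'auto', 'tensor_parallel_size': 1, 'dtype': 'auto', 'quantization': None,
  #    'trust_remote_code': False, 'worker_use_ray': False, 'engine_use_ray': False,
  #    'disable_log_stats': True, 'disable_log_requests': True, 'gpu_memory_utilization': 0.9}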

  @classmethod
  def from_model(
    cls,
    model_id: str,
    dtype: str,
    bentomodel: bentoml.Model | None = None,
    serialisation: LiteralSerialisation = 'safetensors',
    quantise: LiteralQuantise | None = None,
    trust_remote_code: bool = False,
    llm_config: openllm_core.LLMConfig | None = None,
    **engine_args: t.Any,
  ) -> LLM:
    return cls(
      _internal=True,
      model_id=model_id,
      bentomodel=bentomodel,
      quantise=quantise,
      serialisation=serialisation,
      config=llm_config,
      dtype=dtype,
      engine_args=engine_args,
      trust_remote_code=trust_remote_code,
    )
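
  # A minimal construction sketch (illustrative; the model id, BentoML tag, and extra engine arg are
  # placeholders, and `openllm_core.AutoConfig.for_model` is assumed to be the config factory):
  #   llm = LLM.from_model(
  #     'meta-llama/Llama-2-7b-chat-hf',
  #     dtype='auto',
  #     bentomodel=bentoml.models.get('meta-llama--llama-2-7b-chat-hf:latest'),
  #     llm_config=openllm_core.AutoConfig.for_model('llama'),
  #     max_model_len=4096,  # any remaining kwargs are forwarded to vLLM's AsyncEngineArgs
  #   )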

  @property
  def llm_type(self) -> str:
    return normalise_model_name(self.model_id)

  @property
  def identifying_params(self) -> dict[str, str]:
    return {
      'configuration': self.config.model_dump_json(),
      'model_ids': orjson.dumps(self.config['model_ids']).decode(),
      'model_id': self.model_id,
    }

  @functools.cached_property
  def _tokenizer(self):
    import transformers

    return transformers.AutoTokenizer.from_pretrained(self._path, trust_remote_code=self.trust_remote_code)

  async def generate_iterator(
    self,
    prompt: str | None,
    prompt_token_ids: list[int] | None = None,
    stop: str | t.Iterable[str] | None = None,
    stop_token_ids: list[int] | None = None,
    request_id: str | None = None,
    adapter_name: str | None = None,
    **attrs: t.Any,
  ) -> t.AsyncGenerator[RequestOutput, None]:
    from vllm import SamplingParams

    config = self.config.generation_config.model_copy(update=dict_filter_none(attrs))

    # Collect every stop token id (explicit argument, config EOS, tokenizer EOS) so vLLM stops on any of them.
    stop_token_ids = stop_token_ids or []
    eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
    if eos_token_id and not isinstance(eos_token_id, list):
      eos_token_id = [eos_token_id]
    stop_token_ids.extend(eos_token_id or [])
    if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids:
      stop_token_ids.append(config_eos)
    if self._tokenizer.eos_token_id not in stop_token_ids:
      stop_token_ids.append(self._tokenizer.eos_token_id)

    # Normalise `stop` into a set of stop strings.
    if stop is None:
      stop = set()
    elif isinstance(stop, str):
      stop = {stop}
    else:
      stop = set(stop)

    request_id = gen_random_uuid() if request_id is None else request_id
    # Near-zero temperature means greedy decoding, so disable nucleus sampling.
    top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
    config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
    # Only forward config fields that SamplingParams actually accepts.
    params = {k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())}
    sampling_params = SamplingParams(**{k: v for k, v in params.items() if v is not None})

    try:
      async for it in self._model.generate(
        prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
      ):
        yield it
    except Exception as err:
      raise RuntimeError(f'Failed to start generation task: {err}') from err
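
  # Streaming usage sketch (illustrative; must run inside an event loop, and the prompt is a placeholder).
  # vLLM yields cumulative RequestOutput objects, so each chunk carries the text generated so far:
  #   async for chunk in llm.generate_iterator('What is BentoML?', temperature=0.2):
  #     print(chunk.outputs[0].text)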

  async def generate(
    self,
    prompt: str | None,
    prompt_token_ids: list[int] | None = None,
    stop: str | t.Iterable[str] | None = None,
    stop_token_ids: list[int] | None = None,
    request_id: str | None = None,
    adapter_name: str | None = None,
    *,
    _generated: GenerationInput | None = None,
    **attrs: t.Any,
  ) -> GenerationOutput:
    if stop is not None:
      attrs.update({'stop': stop})
    if stop_token_ids is not None:
      attrs.update({'stop_token_ids': stop_token_ids})
    config = self.config.model_copy(update=attrs)
    # One accumulator per sequence; avoid `[[]] * n`, which would alias a single shared list.
    texts, token_ids = [[] for _ in range(config['n'])], [[] for _ in range(config['n'])]
    final_result = None
    async for result in self.generate_iterator(
      prompt,
      prompt_token_ids=prompt_token_ids,
      request_id=request_id,
      adapter_name=adapter_name,
      **config.model_dump(),
    ):
      for output in result.outputs:
        texts[output.index].append(output.text)
        token_ids[output.index].extend(output.token_ids)
      final_result = result
    if final_result is None:
      raise RuntimeError('No result is returned.')
    return GenerationOutput.from_vllm(final_result).model_copy(
      update=dict(
        prompt=prompt,
        outputs=[
          output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
          for output in final_result.outputs
        ],
      )
    )
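
# End-to-end usage sketch (illustrative only; the model id, config factory, and prompt are placeholders,
# and `openllm_core.AutoConfig.for_model` is assumed to be available):
#   import asyncio
#
#   async def main() -> None:
#     llm = LLM.from_model(
#       'mistralai/Mistral-7B-Instruct-v0.1', dtype='auto', llm_config=openllm_core.AutoConfig.for_model('mistral')
#     )
#     result = await llm.generate('What is the meaning of life?', stop=['\n\n'])
#     print(result.outputs[0].text)
#
#   asyncio.run(main())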