Files
OpenLLM/openllm-python/src/openllm/_llm.py
Aaron Pham 072b3e97ec feat: 1.2 APIs (#821)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-03-15 03:49:19 -04:00

555 lines
19 KiB
Python

from __future__ import annotations
import functools, logging, os, warnings, typing as t
import attr, orjson, bentoml, openllm, openllm_core
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import (
AdapterMap,
AdapterTuple,
AdapterType,
LiteralBackend,
LiteralDtype,
LiteralQuantise,
LiteralSerialisation,
)
from openllm.serialisation import _make_tag_components
from openllm_core.exceptions import MissingDependencyError
from openllm_core.utils import (
DEBUG,
check_bool_env,
codegen,
first_not_none,
normalise_model_name,
flatten_attrs,
gen_random_uuid,
getenv,
is_peft_available,
is_transformers_available,
is_vllm_available,
resolve_filepath,
validate_is_path,
)
from .exceptions import ForbiddenAttributeError, OpenLLMException
from .serialisation.constants import PEFT_CONFIG_NAME
if t.TYPE_CHECKING:
import torch, transformers
from peft.config import PeftConfig
from openllm_core._configuration import LLMConfig
from ._runners import Runner
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Concrete tuple class with the fields (adapter_id, name, config); one instance is
# created per adapter in `convert_peft_config_type`.
_AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])
# Shape of the resolved adapter map built by `LLM.adapter_map`:
# adapter type -> {adapter name: (built PEFT config, adapter id)}.
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]]
# File name of the model configuration that `LLM.config` reads to find the architecture.
CONFIG_FILE_NAME = 'config.json'
# Type parameters of `LLM`: M annotates the cached model, T the cached tokenizer.
M = t.TypeVar('M')
T = t.TypeVar('T')
# Main entry point for running a model: it exposes `generate` / `generate_iterator`
# for inference and lazily-created `model`, `tokenizer`, `runner` and `config` properties.
# attrs is told not to generate `__init__` or `__repr__` (init=False, repr=False): the class
# defines both itself, and its `__init__` populates the fields through `__attrs_init__`.
# The type parameters M and T annotate the cached model and tokenizer respectively.
@attr.define(slots=False, repr=False, init=False)
class LLM(t.Generic[M, T]):
async def generate(
  self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
  """Run a complete generation and return one aggregated result.

  This drains `generate_iterator` and concatenates the per-step deltas of every
  output sequence into a single final text / token id list.

  Args:
    prompt: The prompt to generate from.
    prompt_token_ids: Optional pre-tokenised prompt, forwarded to `generate_iterator`.
    stop: Optional stop string(s); merged into the generation config when given.
    stop_token_ids: Optional stop token ids; merged into the generation config when given.
    request_id: Optional request identifier, forwarded to `generate_iterator`.
    adapter_name: Optional adapter name. Only accepted when the backend is 'pt'.
    **attrs: Extra generation options used to construct the generation config.

  Returns:
    A copy of the last streamed result whose `prompt` is set and whose outputs carry
    the fully concatenated text and token ids.

  Raises:
    NotImplementedError: If an adapter is requested on a backend other than 'pt'.
    RuntimeError: If the iterator finished without yielding any result.
  """
  if adapter_name is not None and self.__llm_backend__ != 'pt':
    raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')

  if stop is not None:
    attrs.update({'stop': stop})
  if stop_token_ids is not None:
    attrs.update({'stop_token_ids': stop_token_ids})

  config = self.config.model_construct_env(**attrs)

  # One *independent* accumulator per requested sequence. `[[]] * n` would alias a single
  # list n times, merging the text and token ids of every sequence together when n > 1.
  texts = [[] for _ in range(config['n'])]
  token_ids = [[] for _ in range(config['n'])]

  # Track the last streamed result explicitly so that an empty stream is reported with
  # the intended RuntimeError rather than an unbound-variable error.
  final_result = None
  async for result in self.generate_iterator(
    prompt,
    prompt_token_ids=prompt_token_ids,
    request_id=request_id,
    adapter_name=adapter_name,
    **config.model_dump(),
  ):
    for output in result.outputs:
      texts[output.index].append(output.text)
      token_ids[output.index].extend(output.token_ids)
    final_result = result

  if final_result is None:
    raise RuntimeError('No result is returned.')

  return final_result.model_copy(
    update=dict(
      prompt=prompt,
      outputs=[
        output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
        for output in final_result.outputs
      ],
    )
  )
async def generate_iterator(
  self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
  """Stream generation results, yielding only the delta produced since the previous step.

  Args:
    prompt: The prompt to generate from.
    prompt_token_ids: Optional pre-tokenised prompt, forwarded to the runner.
    stop: Optional stop string, or iterable of stop strings.
    stop_token_ids: Optional iterable of stop token ids. It is copied, never mutated.
    request_id: Optional request identifier; a random one is generated when omitted.
    adapter_name: Optional adapter name. Only accepted when the backend is 'pt'.
    **attrs: Extra generation options used to construct the generation config.

  Yields:
    A result per runner step whose outputs contain only the newly generated text and
    token ids for each sequence.

  Raises:
    NotImplementedError: If an adapter is requested on a backend other than 'pt'.
    RuntimeError: If the runner handle is not set up while `BENTO_PATH` is set, if the
      generation task cannot be started, or if an error occurs while streaming.
  """
  from bentoml._internal.runner.runner_handle import DummyRunnerHandle

  if adapter_name is not None and self.__llm_backend__ != 'pt':
    raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')

  if isinstance(self.runner._runner_handle, DummyRunnerHandle):
    if os.getenv('BENTO_PATH') is not None:
      raise RuntimeError('Runner client failed to set up correctly.')
    else:
      # Outside of a Bento the runner is initialised in-process on first use.
      self.runner.init_local(quiet=True)

  config = self.config.model_construct_env(**attrs)

  # Work on a private copy: extending the caller's list in place would leak the EOS ids
  # into whatever object the caller passed in (and into subsequent calls reusing it).
  stop_token_ids = list(stop_token_ids) if stop_token_ids else []

  def _add_stop_ids(candidates):
    # Falsy values (None, 0, empty list) are treated as "not configured", as before.
    if not candidates:
      return
    # A scalar id is wrapped; a list/tuple is flattened so no nested list is ever appended.
    if not isinstance(candidates, (list, tuple)):
      candidates = [candidates]
    for token_id in candidates:
      if token_id is not None and token_id not in stop_token_ids:
        stop_token_ids.append(token_id)

  _add_stop_ids(attrs.get('eos_token_id', config['eos_token_id']))
  _add_stop_ids(config['eos_token_id'])
  # The tokenizer may not define an EOS token at all; never append None.
  tokenizer_eos = self.tokenizer.eos_token_id
  if tokenizer_eos is not None and tokenizer_eos not in stop_token_ids:
    stop_token_ids.append(tokenizer_eos)

  if stop is None:
    stop = set()
  elif isinstance(stop, str):
    stop = {stop}
  else:
    stop = set(stop)

  request_id = gen_random_uuid() if request_id is None else request_id
  # Per-sequence state used to turn the runner's outputs into deltas.
  previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']

  config = config.model_construct_env(stop=list(stop), stop_token_ids=stop_token_ids)

  try:
    generator = bentoml.io.SSE.from_iterator(
      self.runner.generate_iterator.async_stream(
        prompt, request_id, prompt_token_ids=prompt_token_ids, adapter_name=adapter_name, **config.model_dump()
      )
    )
  except Exception as err:
    raise RuntimeError(f'Failed to start generation task: {err}') from err

  try:
    async for out in generator:
      out = out.data
      generated = GenerationOutput.from_runner(out).model_copy(update=dict(prompt=prompt))
      delta_outputs = [None] * len(generated.outputs)
      for output in generated.outputs:
        i = output.index
        # Slice off what was already emitted for this sequence in earlier steps.
        delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
        previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
        delta_outputs[i] = output.model_copy(update=dict(text=delta_text, token_ids=delta_tokens))
      yield generated.model_copy(update=dict(outputs=delta_outputs))
  except Exception as err:
    raise RuntimeError(f'Exception caught during generation: {err}') from err
# NOTE: If you are here to see how generate_iterator and generate work, see above.
# The code below is mainly internal implementation detail that you don't have to worry about.
# --- Fields populated by `__init__` through `__attrs_init__` ---
# Model id as given by the caller, or the resolved file path when it points to a local path.
_model_id: str
# Model version passed as `revision` at construction time.
_revision: t.Optional[str] #
# Quantisation config supplied by the caller, if any.
_quantization_config: t.Optional[
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
]
# Quantisation scheme after being passed through the backend-specific `_Quantise` hook.
_quantise: t.Optional[LiteralQuantise]
# Extra positional arguments given to `__init__`; forwarded when the model is loaded.
_model_decls: t.Tuple[t.Any, ...]
# Model / tokenizer keyword arguments split out of `**attrs` by `flatten_attrs`.
__model_attrs: t.Dict[str, t.Any] #
__tokenizer_attrs: t.Dict[str, t.Any]
# Model tag; replaced by the stored model's tag when the model is resolved eagerly.
_tag: bentoml.Tag
# Adapter map as produced by `convert_peft_config_type`, or None when no adapters were given.
_adapter_map: t.Optional[AdapterMap] #
_serialisation: LiteralSerialisation
# True when `model_id` was recognised as a local path.
_local: bool
_max_model_len: t.Optional[int] #
_gpu_memory_utilization: float
# --- Dunder-style state: constructor values and lazily-filled caches ---
# Requested dtype name (lower-cased in `__init__`).
__llm_dtype__: t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']] = 'auto'
# Resolved torch dtype; computed on first access of `_torch_dtype`.
__llm_torch_dtype__: 'torch.dtype' = None
# Generation config; taken from `llm_config` or resolved on first access of `config`.
__llm_config__: t.Optional[LLMConfig] = None
__llm_backend__: LiteralBackend = None
# Cache for the `quantization_config` property.
__llm_quantization_config__: t.Optional[
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
] = None
# Caches for the `runner`, `model`, `tokenizer` and `adapter_map` properties.
__llm_runner__: t.Optional[Runner[M, T]] = None
__llm_model__: t.Optional[M] = None
__llm_tokenizer__: t.Optional[T] = None
__llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
# Constructor value; the `trust_remote_code` property lets the TRUST_REMOTE_CODE env var override it.
__llm_trust_remote_code__: bool = False
def __init__(
  self,
  model_id,
  model_version=None,
  model_tag=None,
  llm_config=None,
  backend=None,
  *args,
  quantize=None,
  quantization_config=None,
  adapter_map=None,
  serialisation='safetensors',
  trust_remote_code=False,
  embedded=False,
  dtype='auto',
  low_cpu_mem_usage=True,
  max_model_len=None,
  gpu_memory_utilization=0.9,
  _eager=True,
  **attrs,
):
  """Create an LLM instance and, when eager, make sure the model is in the BentoML store.

  Args:
    model_id: Model id or local path. A local path is resolved and marks the instance as local.
    model_version: Optional model version, stored as the revision.
    model_tag: Optional explicit tag; derived from `model_id`/`model_version` when omitted.
    llm_config: Optional pre-built generation config.
    backend: Backend name. Overridable through the environment; cascaded when unset.
    *args: Extra positional arguments kept as model declarations.
    quantize: Quantisation scheme. Overridable through the environment.
    quantization_config: Explicit quantisation config object.
    adapter_map: Mapping of adapter path/id to adapter name.
    serialisation: Serialisation format, stored on the instance.
    trust_remote_code: Default for the `trust_remote_code` property.
    embedded: Initialise the runner in-process. Requires `_eager=True`.
    dtype: dtype name. Overridable through the environment.
    low_cpu_mem_usage: Added to the attributes that are split into model/tokenizer kwargs.
    max_model_len: Stored on the instance.
    gpu_memory_utilization: Stored on the instance.
    _eager: When true, look the model up in the store and import it if it is missing.
    **attrs: Extra model/tokenizer keyword arguments. A deprecated `torch_dtype` key is
      accepted and used as `dtype`.

  Raises:
    RuntimeError: If `embedded` is requested while `_eager` is False.
  """
  # Validate the argument combination first so an invalid call fails before any model
  # resolution or import work is done.
  if not _eager and embedded:
    raise RuntimeError("Embedded mode is not supported when '_eager' is False.")

  torch_dtype = attrs.pop('torch_dtype', None)  # backward compatible
  if torch_dtype is not None:
    warnings.warn(
      'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',
      DeprecationWarning,
      stacklevel=3,
    )
    dtype = torch_dtype

  _local = False
  if validate_is_path(model_id):
    model_id, _local = resolve_filepath(model_id), True

  backend = getenv('backend', default=backend)
  if backend is None:
    backend = self._cascade_backend()

  dtype = getenv('dtype', default=dtype, var=['TORCH_DTYPE'])
  if dtype is None:
    logger.warning('Setting dtype to auto. Inferring from framework specific models')
    dtype = 'auto'
  if not isinstance(dtype, str):
    # The deprecated `torch_dtype` argument may carry a dtype object rather than a name;
    # reduce e.g. 'torch.float16' to 'float16' so the `.lower()` below cannot fail.
    dtype = str(dtype).rsplit('.', 1)[-1]

  # 'QUANITSE' is a historical misspelling; it stays accepted next to the correct name.
  quantize = getenv('quantize', default=quantize, var=['QUANTISE', 'QUANITSE'])

  attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
  # parsing tokenizer and model kwargs, as the hierarchy is param pass > default
  model_attrs, tokenizer_attrs = flatten_attrs(**attrs)

  if model_tag is None:
    model_tag, model_version = _make_tag_components(model_id, model_version)
    if model_version:
      model_tag = f'{model_tag}:{model_version}'

  self.__attrs_init__(
    model_id=model_id,
    revision=model_version,
    tag=bentoml.Tag.from_taglike(model_tag),
    quantization_config=quantization_config,
    quantise=getattr(self._Quantise, backend)(self, quantize),
    model_decls=args,
    adapter_map=convert_peft_config_type(adapter_map) if adapter_map is not None else None,
    serialisation=serialisation,
    local=_local,
    max_model_len=max_model_len,
    gpu_memory_utilization=gpu_memory_utilization,
    LLM__model_attrs=model_attrs,
    LLM__tokenizer_attrs=tokenizer_attrs,
    llm_dtype__=dtype.lower(),
    llm_backend__=backend,
    llm_config__=llm_config,
    llm_trust_remote_code__=trust_remote_code,
  )

  if _eager:
    try:
      model = bentoml.models.get(self.tag)
    except bentoml.exceptions.NotFound:
      model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
    # resolve the tag
    self._tag = model.tag

  if embedded:
    logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.')
    self.runner.init_local(quiet=True)
class _Quantise:
@staticmethod
def pt(llm: LLM, quantise=None):
return quantise
@staticmethod
def vllm(llm: LLM, quantise=None):
return quantise
@functools.cached_property
def _has_gpus(self):
try:
from cuda import cuda
err, *_ = cuda.cuInit(0)
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('Failed to initialise CUDA runtime binding.')
err, _ = cuda.cuDeviceGetCount()
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError('Failed to get CUDA device count.')
return True
except (ImportError, RuntimeError):
return False
@property
def _torch_dtype(self):
import torch, transformers
_map = _torch_dtype_mapping()
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
config_dtype = getattr(hf_config, 'torch_dtype', torch.float32)
if self.__llm_dtype__ == 'auto':
if torch.cuda.is_available() and config_dtype is torch.float32:
torch_dtype = torch.float16
elif not torch.cuda.is_available():
torch_dtype = torch.float32
else:
torch_dtype = config_dtype
else:
if self.__llm_dtype__ not in _map:
raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
torch_dtype = _map[self.__llm_dtype__]
self.__llm_torch_dtype__ = torch_dtype
return self.__llm_torch_dtype__
@property
def _model_attrs(self):
return {**self.import_kwargs[0], **self.__model_attrs}
@_model_attrs.setter
def _model_attrs(self, model_attrs):
self.__model_attrs = model_attrs
@property
def _tokenizer_attrs(self):
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
def _cascade_backend(self) -> LiteralBackend:
  """Pick a backend when none was given: 'vllm' if GPUs and vLLM are available, else 'pt'.

  A warning is always logged because an implicitly chosen backend may be surprising.
  """
  logger.warning(
    'It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.'
  )
  return 'vllm' if self._has_gpus and is_vllm_available() else 'pt'
def __setattr__(self, attr, value):
  """Set an attribute, refusing the names that are exposed as managed properties.

  Raises:
    ForbiddenAttributeError: If `attr` is 'model', 'tokenizer', 'runner' or 'import_kwargs'.
  """
  protected = {'model', 'tokenizer', 'runner', 'import_kwargs'}
  if attr not in protected:
    return super().__setattr__(attr, value)
  raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
def __del__(self):
try:
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
except AttributeError:
pass
def __repr_args__(self):
yield from (
('model_id', self._model_id if not self._local else self.tag.name),
('revision', self._revision if self._revision else self.tag.version),
('backend', self.__llm_backend__),
('type', self.llm_type),
)
def __repr__(self) -> str:
  """Return the class name followed by the repr arguments as indented JSON."""
  payload = dict(self.__repr_args__())
  rendered = orjson.dumps(payload, option=orjson.OPT_INDENT_2).decode()
  return f'{self.__class__.__name__} {rendered}'
@property
def import_kwargs(self):
  """Default `(model_kwargs, tokenizer_kwargs)` used when importing and loading.

  The model is placed with `device_map='auto'` only when GPUs are detected, and always
  receives the resolved torch dtype. The tokenizer pads and truncates on the left.
  """
  device_map = 'auto' if self._has_gpus else None
  model_kwargs = {'device_map': device_map, 'torch_dtype': self._torch_dtype}
  tokenizer_kwargs = {'padding_side': 'left', 'truncation_side': 'left'}
  return model_kwargs, tokenizer_kwargs
@property
def trust_remote_code(self):
  """Whether remote code is trusted.

  The `TRUST_REMOTE_CODE` environment variable, when set, takes precedence over the
  value given at construction time.
  """
  env = os.getenv('TRUST_REMOTE_CODE')
  if env is None:
    return self.__llm_trust_remote_code__
  return check_bool_env('TRUST_REMOTE_CODE', env)
@property
def model_id(self):
  """The model id (or resolved local path) this instance was created with."""
  return self._model_id
@property
def revision(self):
  """The model revision given at construction time, or None."""
  return self._revision
@property
def tag(self):
  """The BentoML tag identifying this model."""
  return self._tag
@property
def bentomodel(self):
  """Look up this instance's tag in the BentoML model store on every access."""
  model_tag = self.tag
  return bentoml.models.get(model_tag)
@property
def quantization_config(self):
  """Return the quantisation config, resolving and caching it on first access.

  An explicitly supplied config wins; otherwise one is inferred from the requested
  quantisation scheme, which also replaces the instance's model keyword arguments.

  Raises:
    ValueError: If neither a quantisation config nor a quantisation scheme was given.
  """
  if self.__llm_quantization_config__ is not None:
    return self.__llm_quantization_config__

  from ._quantisation import infer_quantisation_config

  if self._quantization_config is not None:
    self.__llm_quantization_config__ = self._quantization_config
    return self.__llm_quantization_config__

  if self._quantise is None:
    raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")

  self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
    self, self._quantise, **self._model_attrs
  )
  return self.__llm_quantization_config__
@property
def has_adapters(self):
  """True when an adapter map was supplied at construction time."""
  adapters = self._adapter_map
  return adapters is not None
@property
def local(self):
  """True when the model id was recognised as a local path."""
  return self._local
@property
def quantise(self):
  """The quantisation scheme in use, or None."""
  return self._quantise
@property
def llm_type(self):
  """The model id passed through `normalise_model_name`."""
  raw_model_id = self._model_id
  return normalise_model_name(raw_model_id)
@property
def llm_parameters(self):
  """Return `((model_decls, model_attrs), tokenizer_attrs)` for loading the model and tokenizer."""
  model_parameters = (self._model_decls, self._model_attrs)
  tokenizer_parameters = self._tokenizer_attrs
  return model_parameters, tokenizer_parameters
@property
def identifying_params(self):
  """Return a string-valued dict describing this LLM.

  It holds the JSON dump of the config, the JSON-encoded `model_ids` entry of the config,
  and the model id.
  """
  configuration = self.config.model_dump_json()
  model_ids = orjson.dumps(self.config['model_ids']).decode()
  return {'configuration': configuration, 'model_ids': model_ids, 'model_id': self.model_id}
@property
def tokenizer(self):
  """The tokenizer, loaded on first access with the tokenizer keyword arguments and then cached."""
  if self.__llm_tokenizer__ is not None:
    return self.__llm_tokenizer__
  tokenizer_kwargs = self.llm_parameters[-1]
  self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **tokenizer_kwargs)
  return self.__llm_tokenizer__
@property
def runner(self):
  """The runner for this LLM, created on first access and then cached."""
  # Imported inside the property, on every access, exactly as before.
  from ._runners import runner as make_runner

  if self.__llm_runner__ is not None:
    return self.__llm_runner__
  self.__llm_runner__ = make_runner(self)
  return self.__llm_runner__
def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
  """Wrap the model in a PEFT model ready for fine-tuning.

  Args:
    adapter_type: Key into the config's fine-tune strategies; a fresh fine-tune config
      for this type is used when the key is absent.
    use_gradient_checking: Passed to PEFT as `use_gradient_checkpointing`.
    **attrs: Overrides applied to the fine-tune config.

  Returns:
    A `(peft_model, tokenizer)` tuple.

  Raises:
    RuntimeError: If the backend is not 'pt'.
  """
  if self.__llm_backend__ != 'pt':
    raise RuntimeError('Fine tuning is only supported for PyTorch backend.')

  from peft.mapping import get_peft_model
  from peft.utils.other import prepare_model_for_kbit_training

  # The model is prepared first, then the PEFT config is built, as in the original call.
  kbit_ready_model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking)
  strategies = self.config['fine_tune_strategies']
  strategy = strategies.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
  peft_config = strategy.train().with_config(**attrs).build()
  peft_model = get_peft_model(kbit_ready_model, peft_config)

  if DEBUG:
    peft_model.print_trainable_parameters()
  return peft_model, self.tokenizer
def prepare_for_training(self, *args, **attrs):
  """Deprecated alias of `prepare`: log a deprecation warning, then delegate unchanged."""
  logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.')
  prepared = self.prepare(*args, **attrs)
  return prepared
@property
def adapter_map(self):
  """Resolved adapters: adapter type -> {adapter name: (built PEFT config, adapter id)}.

  The map is built on first access and cached afterwards.

  Raises:
    MissingDependencyError: If `peft` is not available.
    AttributeError: If no adapter map was supplied at construction time.
  """
  if not is_peft_available():
    raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
  if not self.has_adapters:
    raise AttributeError('Adapter map is not available.')
  assert self._adapter_map is not None

  if self.__llm_adapter_map__ is not None:
    return self.__llm_adapter_map__

  resolved: ResolvedAdapterMap = {adapter_type: {} for adapter_type in self._adapter_map}
  for adapter_type, adapters in self._adapter_map.items():
    # Prefer the strategy configured for this adapter type, else a fresh fine-tune config.
    base = first_not_none(
      self.config['fine_tune_strategies'].get(adapter_type),
      default=self.config.make_fine_tune_config(adapter_type),
    )
    per_type = resolved[adapter_type]
    for adapter in adapters:
      peft_config = base.with_config(**adapter.config).build()
      per_type[adapter.name] = (peft_config, adapter.adapter_id)

  self.__llm_adapter_map__ = resolved
  return self.__llm_adapter_map__
@property
def model(self):
  """The model, loaded on first access with the stored declarations and kwargs, then cached."""
  if self.__llm_model__ is not None:
    return self.__llm_model__
  self.__llm_model__ = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
  return self.__llm_model__
@property
def config(self):
  """Return the generation config, resolving it from the model's `config.json` on first access.

  The file is taken from the local model directory, from the stored BentoML model, or,
  when the model is not available from the store, from the Hugging Face cache. The first
  entry of its `architectures` list that has a known mapping selects the config class.

  Raises:
    MissingDependencyError: If the Hugging Face fallback is needed but `transformers` is missing.
    ValueError: If `config.json` cannot be located or names no supported architecture.
  """
  if self.__llm_config__ is not None:
    return self.__llm_config__

  if self._local:
    config_file = os.path.join(self.model_id, CONFIG_FILE_NAME)
  else:
    try:
      config_file = self.bentomodel.path_of(CONFIG_FILE_NAME)
    # `bentoml.models.get` signals a missing model with `NotFound` (see `__init__`); it has
    # to be caught here as well, otherwise the Hugging Face fallback below is unreachable.
    except (OpenLLMException, bentoml.exceptions.NotFound) as err:
      if not is_transformers_available():
        raise MissingDependencyError(
          "Requires 'transformers' to be available. Do 'pip install transformers'"
        ) from err
      from transformers.utils import cached_file

      try:
        config_file = cached_file(self.model_id, CONFIG_FILE_NAME)
      except Exception as err:
        raise ValueError(
          "Failed to determine architecture from 'config.json'. If this is a gated model, make sure to pass in HUGGING_FACE_HUB_TOKEN"
        ) from err

  if not os.path.exists(config_file):
    raise ValueError(f"Failed to find 'config.json' (config_json_path={config_file})")

  with open(config_file, 'r', encoding='utf-8') as f:
    loaded_config = orjson.loads(f.read())

  # A missing or empty 'architectures' entry is reported like an unsupported one instead
  # of silently returning None, which would only fail later at the first config lookup.
  for architecture in loaded_config.get('architectures') or []:
    if architecture in self._architecture_mappings:
      self.__llm_config__ = openllm_core.AutoConfig.for_model(
        self._architecture_mappings[architecture]
      ).model_construct_env()
      break
  else:
    raise ValueError(f"Failed to find architecture from 'config.json' (config_json_path={config_file})")

  return self.__llm_config__
@functools.lru_cache(maxsize=1)
def _torch_dtype_mapping() -> dict[str, torch.dtype]:
import torch
return {
'half': torch.float16,
'float16': torch.float16, #
'float': torch.float32,
'float32': torch.float32, #
'bfloat16': torch.bfloat16,
}
def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
  """Group adapters by PEFT type after reading each adapter's PEFT config file.

  Args:
    adapter_map: Mapping of adapter path or Hugging Face Hub id to adapter name.

  Returns:
    A mapping of lower-cased PEFT type to a tuple of adapter tuples
    `(adapter_id, name, config)`, in the order the adapters were given.

  Raises:
    RuntimeError: If `peft` is not installed.
    ValueError: If an adapter has no name, or its PEFT config cannot be found locally
      or downloaded from the Hub.
  """
  if not is_peft_available():
    raise RuntimeError(
      "LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
    )
  from huggingface_hub import hf_hub_download

  resolved: AdapterMap = {}
  for path_or_adapter_id, name in adapter_map.items():
    if name is None:
      raise ValueError('Adapter name must be specified.')

    # Prefer a config file in a local adapter directory; otherwise fetch it from the Hub.
    local_config = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
    if os.path.isfile(local_config):
      config_file = local_config
    else:
      try:
        config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
      except Exception as err:
        raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err

    # Read as UTF-8 explicitly (as done for 'config.json') rather than relying on the
    # platform's default encoding.
    with open(config_file, 'r', encoding='utf-8') as file:
      resolved_config = orjson.loads(file.read())

    _peft_type = resolved_config['peft_type'].lower()
    resolved[_peft_type] = resolved.get(_peft_type, ()) + (
      _AdapterTuple((path_or_adapter_id, name, resolved_config)),
    )
  return resolved