mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-01-26 16:27:51 -05:00
* wip
* feat: GPTQ transformers integration
* fix: only load if variable is available and add changelog
* chore: remove boilerplate check

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
101 lines
3.8 KiB
Python
'''Serialisation utilities for OpenLLM.

Currently supports transformers for PyTorch, TensorFlow and Flax.

GGML format support is currently a work in progress.
'''
from __future__ import annotations

import importlib
import typing as t

import cloudpickle
import fs

import openllm

from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from openllm_core._typing_compat import M
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import T

if t.TYPE_CHECKING:
  import bentoml

  from . import constants as constants
  from . import ggml as ggml
  from . import transformers as transformers

P = ParamSpec('P')

def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
  '''Load the tokenizer from the BentoML store.

  By default, it will try to find the bentomodel in the local store.
  If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
  '''
  from .transformers._helpers import infer_tokenizers_from_llm
  from .transformers._helpers import process_config

  config, *_ = process_config(llm._bentomodel.path, llm.trust_remote_code)

  bentomodel_fs = fs.open_fs(llm._bentomodel.path)
  if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
    # A tokenizer saved alongside the model via 'custom_objects' is unpickled directly.
    with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
      try:
        tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
      except KeyError:
        raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
                                                  "For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
  else:
    # Otherwise, load the tokenizer from the files stored inside the bentomodel directory.
    tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)

  # Ensure a pad token is always set, falling back from the model config to the EOS token.
  if tokenizer.pad_token_id is None:
    if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
    elif config.eos_token_id is not None: tokenizer.pad_token_id = config.eos_token_id
    elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
    else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
  return tokenizer
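
# Illustrative usage (a sketch, not part of this module): saving a tokenizer under
# 'custom_objects' so that 'load_tokenizer' can later recover it from the Bento model
# store. The model name 'my-llm' and the checkpoint below are hypothetical placeholders.
#
#   import bentoml
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
#   tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
#   bentoml.transformers.save_model('my-llm', model, custom_objects={'tokenizer': tokenizer})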

class _Caller(t.Protocol[P]):
  def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
    ...

_extras = ['get', 'import_model', 'load_model']

def _make_dispatch_function(fn: str) -> _Caller[P]:
  def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
    """Generic dispatch to the correct serialisation submodule based on the LLM backend.

    > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "tf", "flax", "vllm")'

    > [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__ == "ggml"'
    """
    serde = 'transformers'
    if llm.__llm_backend__ == 'ggml': serde = 'ggml'
    return getattr(importlib.import_module(f'.{serde}', __name__), fn)(llm, *args, **kwargs)

  return caller
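
# A minimal sketch of what the dispatch resolves to. Assuming 'llm.__llm_backend__'
# is 'pt', the generated caller for 'load_model' behaves like:
#
#   from openllm.serialisation import transformers as _serde
#   model = _serde.load_model(llm)
#
# whereas a 'ggml' backend routes the same call to 'openllm.serialisation.ggml.load_model'.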

if t.TYPE_CHECKING:

  def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
    ...

  def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
    ...

  def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
    ...

_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]

def __dir__() -> list[str]:
  return sorted(__all__)

def __getattr__(name: str) -> t.Any:
  if name == 'load_tokenizer': return load_tokenizer
  elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
  elif name in _extras: return _make_dispatch_function(name)
  else: raise AttributeError(f'{__name__} has no attribute {name}')
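
# How the lazy attributes behave in practice (a sketch; 'llm' is assumed to be an
# existing 'openllm.LLM' instance):
#
#   import openllm
#
#   # 'get' is not defined at module level, so '__getattr__' intercepts the lookup
#   # and returns a dispatch function built by '_make_dispatch_function'.
#   bentomodel = openllm.serialisation.get(llm)
#
#   # Submodules listed in '_import_structure' are imported on first attribute access.
#   ggml_module = openllm.serialisation.ggml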