Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-02-20 07:33:55 -05:00
fix: loading correct local models (#599)
* fix(model): loading local correctly
* chore: update repr and correct bentomodel processor
* ci: auto fixes from pre-commit.ci (for more information, see https://pre-commit.ci)
* chore: cleanup transformers implementation
* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,10 +1,3 @@
-"""Serialisation utilities for OpenLLM.
-
-Currently supports transformers for PyTorch, and vLLM.
-
-Currently, GGML format is working in progress.
-"""
-
 from __future__ import annotations
 import importlib
 import typing as t
@@ -14,32 +7,26 @@ import fs
 import openllm

-from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
-from openllm_core._typing_compat import M
-from openllm_core._typing_compat import ParamSpec
-from openllm_core._typing_compat import T
-
 if t.TYPE_CHECKING:
-  import transformers as _transformers
   import bentoml

   from . import constants as constants
   from . import ggml as ggml
-  from . import transformers as transformers
+  import transformers as transformers
 else:
-  _transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
+  transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')

-P = ParamSpec('P')
-
-def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
+def load_tokenizer(llm, **tokenizer_attrs):
   """Load the tokenizer from BentoML store.

   By default, it will try to find the bentomodel whether it is in store..
   If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
   """
   tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs}
+  from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
+  from .transformers._helpers import process_config
+
   config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
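
The `openllm.utils.LazyLoader('transformers', globals(), 'transformers')` call above defers the heavyweight `transformers` import until the module is first touched at runtime. A minimal sketch of the pattern with the same three-argument shape (an illustration of the idiom, not OpenLLM's actual implementation):

import importlib
import types

class LazyLoader(types.ModuleType):
  """Stand-in module object that imports `module_name` on first attribute access."""

  def __init__(self, local_name, parent_globals, module_name):
    super().__init__(module_name)
    self._local_name = local_name
    self._parent_globals = parent_globals
    self._module_name = module_name

  def _load(self):
    module = importlib.import_module(self._module_name)
    # Replace the proxy in the owning namespace so later lookups hit the real module.
    self._parent_globals[self._local_name] = module
    return module

  def __getattr__(self, item):
    return getattr(self._load(), item)

transformers = LazyLoader('transformers', globals(), 'transformers')
print(transformers.__version__)  # the real import happens here, not at module load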
@@ -48,14 +35,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
   if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
     with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
       try:
-        tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
+        tokenizer = cloudpickle.load(cofile)['tokenizer']
       except KeyError:
         raise openllm.exceptions.OpenLLMException(
           "Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
           'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
         ) from None
   else:
-    tokenizer = _transformers.AutoTokenizer.from_pretrained(
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
       bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
     )
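
`load_tokenizer` only succeeds when a tokenizer was pickled into the bentomodel's custom objects at save time; otherwise it falls back to `AutoTokenizer.from_pretrained` on the model directory. A hedged sketch of the save side, following the suggestion embedded in the error message above (the model id and store name are placeholders, and exact `save_model` arguments depend on your BentoML version):

import bentoml
import transformers

model_id = 'facebook/opt-125m'  # placeholder model id for illustration
model = transformers.AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# The tokenizer is serialised into the bentomodel's custom-objects file,
# which is where load_tokenizer looks before falling back to AutoTokenizer.
bentoml.transformers.save_model('opt-125m', model, custom_objects={'tokenizer': tokenizer})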
@@ -71,15 +58,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
   return tokenizer

-class _Caller(t.Protocol[P]):
-  def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...
-
 _extras = ['get', 'import_model', 'load_model']

-def _make_dispatch_function(fn: str) -> _Caller[P]:
-  def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
+def _make_dispatch_function(fn):
+  def caller(llm, *args, **kwargs):
     """Generic function dispatch to correct serialisation submodules based on LLM runtime.

     > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
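
The hunk cuts off before `caller`'s body, but its docstring indicates the routing rule: the 'pt' and 'vllm' backends both resolve to `openllm.serialisation.transformers`. A hypothetical reconstruction of that dispatch, inferred from the docstring and the module's `importlib` import rather than copied from the source:

import importlib

def _make_dispatch_function(fn):
  def caller(llm, *args, **kwargs):
    # 'pt' and 'vllm' share the transformers serialisation path; other
    # backends (e.g. a GGML runtime) would map to their own submodule.
    serde = 'transformers' if llm.__llm_backend__ in ('pt', 'vllm') else llm.__llm_backend__
    module = importlib.import_module(f'.{serde}', 'openllm.serialisation')
    return getattr(module, fn)(llm, *args, **kwargs)
  return caller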
@@ -94,24 +77,15 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
   return caller

-if t.TYPE_CHECKING:
-
-  def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
-
-  def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
-
-  def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
-
 _import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
 __all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]

-def __dir__() -> list[str]:
+def __dir__():
   return sorted(__all__)

-def __getattr__(name: str) -> t.Any:
+def __getattr__(name):
   if name == 'load_tokenizer':
     return load_tokenizer
   elif name in _import_structure:
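
Module-level `__getattr__` and `__dir__` are the PEP 562 lazy-import idiom: submodules named in `_import_structure` are imported only when first accessed. A self-contained sketch of the same idiom outside OpenLLM:

# mypackage/__init__.py
import importlib

_import_structure = {'ggml': [], 'transformers': [], 'constants': []}
__all__ = ['ggml', 'transformers', 'constants']

def __dir__():
  return sorted(__all__)

def __getattr__(name):
  if name in _import_structure:
    # Import the submodule on first access and cache it on the package.
    module = importlib.import_module(f'.{name}', __name__)
    globals()[name] = module
    return module
  raise AttributeError(f'module {__name__!r} has no attribute {name!r}')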