fix: loading correct local models (#599)

* fix(model): loading local models correctly

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update repr and correct bentomodel processor

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: cleanup transformers implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Aaron Pham
2023-11-10 02:36:12 -05:00
committed by GitHub
parent 5e45245457
commit fa2038f4e2
11 changed files with 121 additions and 111 deletions

View File

@@ -24,7 +24,7 @@ repos:
language: system
always_run: true
pass_filenames: false
entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py
entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py openllm-python/src/openllm/serialisation/__init__.pyi
- repo: https://github.com/editorconfig-checker/editorconfig-checker.python
rev: '2.7.3'
hooks:

View File

@@ -121,7 +121,7 @@ _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple'
@attr.define(slots=True, repr=False, init=False)
class LLM(t.Generic[M, T]):
class LLM(t.Generic[M, T], ReprMixin):
_model_id: str
_revision: str | None
_quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
@@ -136,6 +136,7 @@ class LLM(t.Generic[M, T]):
_prompt_template: PromptTemplate | None
_system_message: str | None
_bentomodel: bentoml.Model = attr.field(init=False)
__llm_config__: LLMConfig | None = None
__llm_backend__: LiteralBackend = None # type: ignore
__llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -229,6 +230,7 @@ class LLM(t.Generic[M, T]):
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
# resolve the tag
self._tag = model.tag
self._bentomodel = model
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(
@@ -256,7 +258,14 @@ class LLM(t.Generic[M, T]):
if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr,value)
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self) -> ReprArgs:
yield 'model_id', self._model_id if not self._local else self.tag.name
yield 'revision', self._revision if self._revision else self.tag.version
yield 'backend', self.__llm_backend__
yield 'type', self.llm_type
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
@property
def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
@property
@@ -268,7 +277,7 @@ class LLM(t.Generic[M, T]):
@property
def tag(self)->bentoml.Tag:return self._tag
@property
def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
def bentomodel(self)->bentoml.Model:return self._bentomodel
@property
def config(self)->LLMConfig:
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
@@ -282,6 +291,8 @@ class LLM(t.Generic[M, T]):
return self.__llm_quantization_config__
@property
def has_adapters(self)->bool:return self._adapter_map is not None
@property
def local(self)->bool:return self._local
# NOTE: The section below defines a loose contract with langchain's LLM interface.
@property
def llm_type(self)->str:return normalise_model_name(self._model_id)
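
For context (not part of the diff): a minimal sketch of what the cached handle and the ReprMixin hooks give at runtime, assuming an openllm.LLM can be constructed directly from a model id; the id below is hypothetical.

import openllm

llm = openllm.LLM('facebook/opt-125m')  # hypothetical model id
first = llm.bentomodel   # returns the cached self._bentomodel, no store lookup per access
second = llm.bentomodel
assert first is second   # same object every time
print(repr(llm))         # ReprMixin renders the __repr_args__ pairs: model_id, revision, backend, type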

View File

@@ -78,6 +78,8 @@ def parse_config_options(
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
_bentoml_config_options_opts = [
'tracing.sample_rate=1.0',
'api_server.max_runner_connections=25',
f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
f'api_server.traffic.timeout={server_timeout}',
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',

View File

@@ -199,7 +199,9 @@ def _build(
'--serialisation',
t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
]
if quantize:

View File

@@ -420,7 +420,9 @@ def start_command(
serialisation = t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
@@ -592,7 +594,9 @@ def start_grpc_command(
serialisation = t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
@@ -789,7 +793,9 @@ def import_command(
backend=backend,
serialisation=t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
)
backend_warning(llm.__llm_backend__)
@@ -963,7 +969,9 @@ def build_command(
quantize=quantize,
serialisation=t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
)
backend_warning(llm.__llm_backend__)
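
For reference (not in the diff), a simplified sketch of the selection pattern repeated in these hunks, assuming first_not_none is a plain "first non-None value wins, else the default" helper:

def first_not_none(*args, default=None):
  # simplified behaviour sketch of the openllm_core helper
  return next((arg for arg in args if arg is not None), default)

assert first_not_none(None, default='legacy') == 'legacy'                 # no --serialisation given
assert first_not_none('safetensors', default='legacy') == 'safetensors'   # explicit flag always wins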

View File

@@ -1,10 +1,3 @@
"""Serialisation utilities for OpenLLM.
Currently supports transformers for PyTorch and vLLM.
GGML support is currently a work in progress.
"""
from __future__ import annotations
import importlib
import typing as t
@@ -14,32 +7,26 @@ import fs
import openllm
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from openllm_core._typing_compat import M
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import T
if t.TYPE_CHECKING:
import transformers as _transformers
import bentoml
from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers
import transformers as transformers
else:
_transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
P = ParamSpec('P')
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
def load_tokenizer(llm, **tokenizer_attrs):
"""Load the tokenizer from BentoML store.
By default, it will try to find the bentomodel in the BentoML store.
If the model is not found, it raises ``bentoml.exceptions.NotFound``.
"""
tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs}
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from .transformers._helpers import process_config
config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
@@ -48,14 +35,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
try:
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
tokenizer = cloudpickle.load(cofile)['tokenizer']
except KeyError:
raise openllm.exceptions.OpenLLMException(
"Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
) from None
else:
tokenizer = _transformers.AutoTokenizer.from_pretrained(
tokenizer = transformers.AutoTokenizer.from_pretrained(
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
)
@@ -71,15 +58,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
return tokenizer
class _Caller(t.Protocol[P]):
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...
_extras = ['get', 'import_model', 'load_model']
def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
def _make_dispatch_function(fn):
def caller(llm, *args, **kwargs):
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
@@ -94,24 +77,15 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
return caller
if t.TYPE_CHECKING:
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
def __dir__() -> list[str]:
def __dir__():
return sorted(__all__)
def __getattr__(name: str) -> t.Any:
def __getattr__(name):
if name == 'load_tokenizer':
return load_tokenizer
elif name in _import_structure:
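
A hedged sketch (not the literal module body) of the lazy dispatch these hunks rely on: the module-level get, import_model and load_model callables are created on first attribute access and forward to the backend-appropriate submodule.

import importlib

_extras = ['get', 'import_model', 'load_model']

def _make_dispatch_function(fn):
  def caller(llm, *args, **kwargs):
    # 'pt' and 'vllm' backends serialise through the transformers submodule, anything else through ggml
    serde = 'transformers' if llm.__llm_backend__ in ('pt', 'vllm') else 'ggml'
    return getattr(importlib.import_module(f'.{serde}', __name__), fn)(llm, *args, **kwargs)
  return caller

def __getattr__(name):
  if name in _extras:
    return _make_dispatch_function(name)
  raise AttributeError(name)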

View File

@@ -0,0 +1,22 @@
"""Serialisation utilities for OpenLLM.
Currently supports transformers for PyTorch and vLLM.
GGML support is currently a work in progress.
"""
from typing import Any
from bentoml import Model
from openllm import LLM
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers
def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: ...
def get(llm: LLM[M, T]) -> Model: ...
def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
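
For orientation (not part of the stub), a hedged usage sketch of the API it declares, assuming an already constructed LLM with a hypothetical model id:

import openllm

llm = openllm.LLM('facebook/opt-125m')                                          # hypothetical model id
bentomodel = openllm.serialisation.import_model(llm, trust_remote_code=False)   # import into the store if missing
bentomodel = openllm.serialisation.get(llm)                                      # look it up afterwards
model = openllm.serialisation.load_model(llm)                                    # materialise the weights
tokenizer = openllm.serialisation.load_tokenizer(llm)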

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import importlib
import logging
import typing as t
import attr
import orjson
@@ -29,25 +28,17 @@ from ._helpers import process_config
from .weights import HfIgnore
if t.TYPE_CHECKING:
import types
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
logger = logging.getLogger(__name__)
__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__
def _patch_correct_tag(
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
) -> None:
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
# NOTE: The following won't hit for local models since we generate a correct version based on the local path hash. It only applies when the model comes from the HF Hub.
if llm.revision is not None:
return
if not llm._local:
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
@@ -55,7 +46,7 @@ def _patch_correct_tag(
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm._tag.version is None:
if llm.tag.version is None:
_object_setattr(
llm, '_tag', attr.evolve(llm.tag, version=_revision)
) # HACK: This copies the correct revision into llm.tag
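
To make the evolve step concrete, a hedged sketch (tag name and revision below are hypothetical):

import attr
import bentoml

tag = bentoml.Tag('pt-facebook-opt-125m')       # no version resolved yet
patched = attr.evolve(tag, version='abcdef0')   # copy of the tag with the HF commit hash as its version
assert patched.name == tag.name and patched.version == 'abcdef0'
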
@@ -64,13 +55,12 @@ def _patch_correct_tag(
@inject
def import_model(
llm: openllm.LLM[M, T],
*decls: t.Any,
trust_remote_code: bool,
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
**attrs: t.Any,
) -> bentoml.Model:
def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
_base_decls, _base_attrs = llm.llm_parameters[0]
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
if llm._local:
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
@@ -78,7 +68,7 @@ def import_model(
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
metadata = {'safe_serialisation': safe_serialisation}
if quantize:
metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
@@ -87,9 +77,12 @@ def import_model(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)
if not llm._local:
metadata['_revision'] = get_hash(config)
else:
metadata['_revision'] = llm.revision
signatures: DictStrAny = {}
signatures = {}
if quantize == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
@@ -125,8 +118,8 @@ def import_model(
tokenizer.pad_token = tokenizer.eos_token
model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
external_modules = [importlib.import_module(tokenizer.__module__)]
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module='openllm.serialisation.transformers',
@@ -150,11 +143,20 @@ def import_model(
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
llm.model_id,
*decls,
local_files_only=True,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# we will clone everything into the bentomodel path without loading the model into memory
snapshot_download(
@@ -175,15 +177,10 @@ def import_model(
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()
del model
return bentomodel
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
def get(llm):
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
@@ -203,7 +200,13 @@ def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
) from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
def check_unintialised_params(model):
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
def load_model(llm, *decls, **attrs):
if llm._quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
@@ -232,7 +235,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
)
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
@@ -242,4 +247,5 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
device_map=device_map,
**attrs,
)
return t.cast('M', model)
check_unintialised_params(model)
return model
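
A self-contained illustration (not from the diff) of what the new check_unintialised_params guard catches: parameters still on the 'meta' device were never materialised with real weights.

import torch

layer = torch.nn.Linear(4, 4, device='meta')   # simulate a weight that was never actually loaded
missing = [n for n, p in layer.named_parameters() if p.data.device == torch.device('meta')]
print(missing)   # ['weight', 'bias'] -- check_unintialised_params would raise RuntimeError here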

View File

@@ -2,43 +2,28 @@ from __future__ import annotations
import copy
import typing as t
import torch
import transformers
import openllm
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
from openllm_core.exceptions import OpenLLMException
if t.TYPE_CHECKING:
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from ..._llm import LLM
def get_hash(config: transformers.PretrainedConfig) -> str:
def get_hash(config) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def process_config(
model_id: str, trust_remote_code: bool, **attrs: t.Any
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
trust_remote_code: Whether to trust_remote_code or not
attrs: All possible attributes that can be processed by ``transformers.AutoConfig.from_pretrained``
Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remaining attributes that can be used by the Model class.
"""
def process_config(model_id, trust_remote_code, **attrs):
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
@@ -52,7 +37,7 @@ def process_config(
return config, hub_attrs, attrs
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
@@ -70,11 +55,5 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
idx = 1
else:
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
raise OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
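
For reference (not in the diff), a hedged usage sketch of process_config, assuming 'revision' is among the recognised HUB_ATTRS; the model id is hypothetical.

from openllm.serialisation.transformers._helpers import process_config

config, hub_attrs, model_attrs = process_config('facebook/opt-125m', False, revision='main')
# config      -> transformers.PretrainedConfig for the model
# hub_attrs   -> kwargs routed to the Hub download, e.g. {'revision': 'main'}
# model_attrs -> whatever remains for the model class itself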

View File

@@ -2,11 +2,15 @@ from __future__ import annotations
import traceback
import typing as t
from pathlib import Path
import attr
from huggingface_hub import HfApi
from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath
from openllm_core.utils import validate_is_path
if t.TYPE_CHECKING:
@@ -40,6 +44,8 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
if validate_is_path(model_id):
return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
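
A hedged usage sketch of the updated helper (the path and model id below are hypothetical; calls use has_safetensors_weights as defined above):

has_safetensors_weights('/path/to/local/checkpoint')          # local path: glob the directory for *.safetensors
has_safetensors_weights('facebook/opt-125m', revision=None)   # hub id: inspect sibling filenames via HfApi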

View File

@@ -93,7 +93,7 @@ multiline-quotes = "double"
docstring-quotes = "double"
[lint.extend-per-file-ignores]
"openllm-client/src/openllm_client/__init__.pyi" = ["I001"]
"/**/*.pyi" = ["I001"]
"openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
"openllm-python/src/openllm/_llm.py" = ["F811"]
"openllm-core/src/openllm_core/utils/import_utils.py" = ["PLW0603", "F811"]