mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-23 08:28:24 -04:00
perf: improve build logics and cleanup speed (#657)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -4,8 +4,8 @@ import importlib
|
||||
import cloudpickle
|
||||
import fs
|
||||
|
||||
import openllm
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@@ -31,7 +31,7 @@ def load_tokenizer(llm, **tokenizer_attrs):
|
||||
try:
|
||||
tokenizer = cloudpickle.load(cofile)['tokenizer']
|
||||
except KeyError:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
raise OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
|
||||
) from None
|
||||
|
||||
@@ -13,7 +13,7 @@ import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelOptions, ModelSignature
|
||||
from openllm_core._typing_compat import M, T
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
|
||||
from ._helpers import get_hash, infer_autoclass_from_llm, process_config
|
||||
from .weights import HfIgnore
|
||||
@@ -24,7 +24,7 @@ __all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
|
||||
def _patch_correct_tag(llm, config, _revision=None):
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if llm.revision is not None:
|
||||
return
|
||||
@@ -76,7 +76,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
|
||||
|
||||
if quantize == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
raise OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
signatures['generate'] = {'batchable': False}
|
||||
@@ -175,7 +175,7 @@ def get(llm):
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
raise OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
_patch_correct_tag(
|
||||
@@ -185,9 +185,7 @@ def get(llm):
|
||||
)
|
||||
return model
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed while getting stored artefact (lookup for traceback):\n{err}'
|
||||
) from err
|
||||
raise OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
|
||||
|
||||
def check_unintialised_params(model):
|
||||
@@ -216,13 +214,11 @@ def load_model(llm, *decls, **attrs):
|
||||
_quantise = llm.bentomodel.info.metadata['_quantize']
|
||||
if _quantise == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
raise OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})"
|
||||
)
|
||||
raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
|
||||
# TODO: investigate load with flash attention
|
||||
model = auto_class.from_pretrained(
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, HUB_ATTRS
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from ..._llm import LLM
|
||||
|
||||
|
||||
def get_hash(config) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
@@ -34,7 +28,7 @@ def process_config(model_id, trust_remote_code, **attrs):
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
|
||||
def infer_autoclass_from_llm(llm, config, /):
|
||||
if llm.trust_remote_code:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
|
||||
@@ -13,7 +13,6 @@ if t.TYPE_CHECKING:
|
||||
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
||||
|
||||
import openllm
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
@@ -39,7 +38,7 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
if validate_is_path(model_id):
|
||||
return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return next((True for _ in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
@@ -52,7 +51,7 @@ class HfIgnore:
|
||||
gguf = '*.gguf'
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id):
|
||||
|
||||
Reference in New Issue
Block a user