fix(yapf): align weird new lines break [generated] [skip ci] (#284)
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
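The hunks below are the output of re-running yapf over the tree, collapsing call arguments and literals that were previously wrapped across several lines. As a rough illustration only (not the project's actual formatter configuration, which is not shown on this page), the same kind of collapse can be reproduced with yapf's Python API:

# Minimal sketch: the style options below (based_on_style, indent_width, column_limit)
# are assumptions for illustration, not the repository's real yapf settings.
from yapf.yapflib.yapf_api import FormatCode

wrapped = ("safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'),\n"
           "                                                  default=llm._serialisation_format == 'safetensors')\n")
# With a large enough column limit, yapf rejoins the continuation line onto the call,
# which is the shape of the collapsed '+' lines in the hunks that follow.
formatted, changed = FormatCode(wrapped, style_config='{based_on_style: google, indent_width: 2, column_limit: 192}')
print(changed)    # True when yapf rewrote the snippet
print(formatted)  # the single-line form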
@@ -59,14 +59,12 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
   return tokenizer

 class _Caller(t.Protocol[P]):

   def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
     ...

 _extras = ['get', 'import_model', 'load_model']

 def _make_dispatch_function(fn: str) -> _Caller[P]:

   def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
     """Generic function dispatch to correct serialisation submodules based on LLM runtime.
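For context on the hunk above: the truncated docstring says the generated caller dispatches to the correct serialisation submodule based on the LLM runtime. A minimal sketch of that pattern follows; the 'openllm.serialisation' package path and the llm.runtime attribute are assumptions for illustration, since the hunk does not show the body of caller.

# Hypothetical sketch of the dispatch pattern described in the docstring above.
import importlib
import typing as t

def make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
  def caller(llm: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    # e.g. resolve openllm.serialisation.transformers.get / .import_model / .load_model
    submodule = importlib.import_module(f'.{llm.runtime}', 'openllm.serialisation')
    return getattr(submodule, fn)(llm, *args, **kwargs)
  return caller

# The module would then expose one wrapper per name in _extras, e.g. get = make_dispatch_function('get').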
@@ -7,6 +7,5 @@ FRAMEWORK_TO_AUTOCLASS_MAPPING = {
     'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
 }
 HUB_ATTRS = [
-    'cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision',
-    'subfolder', 'use_auth_token'
+    'cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token'
 ]
@@ -13,11 +13,7 @@ if t.TYPE_CHECKING:
 _conversion_strategy = {'pt': 'ggml'}

-def import_model(llm: openllm.LLM[t.Any, t.Any],
-                 *decls: t.Any,
-                 trust_remote_code: bool = True,
-                 **attrs: t.Any,
-                 ) -> bentoml.Model:
+def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
   raise NotImplementedError('Currently work in progress.')

 def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
@@ -68,24 +68,18 @@ def import_model(llm: openllm.LLM[M, T],
   config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
   _, tokenizer_attrs = llm.llm_parameters
   quantize_method = llm._quantize_method
-  safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'),
-                                                    default=llm._serialisation_format == 'safetensors')
+  safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
   # Disable safe serialization with vLLM
   if llm.__llm_backend__ == 'vllm': safe_serialisation = False
-  metadata: DictStrAny = {
-      'safe_serialisation': safe_serialisation,
-      '_quantize': quantize_method is not None and quantize_method
-  }
+  metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
   signatures: DictStrAny = {}

   if quantize_method == 'gptq':
     if not openllm.utils.is_autogptq_available():
       raise openllm.exceptions.OpenLLMException(
-          "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
-      )
+          "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
     if llm.config['model_type'] != 'causal_lm':
-      raise openllm.exceptions.OpenLLMException(
-          f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
+      raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
     signatures['generate'] = {'batchable': False}
   else:
     # this model might be called with --quantize int4, therefore we need to pop this out
@@ -95,10 +89,7 @@ def import_model(llm: openllm.LLM[M, T],
     if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
   metadata['_framework'] = 'pt' if llm.__llm_backend__ == 'vllm' else llm.__llm_backend__

-  tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id,
-                                                             trust_remote_code=trust_remote_code,
-                                                             **hub_attrs,
-                                                             **tokenizer_attrs)
+  tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
   if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

   external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
@@ -117,25 +108,18 @@ def import_model(llm: openllm.LLM[M, T],
     if quantize_method == 'gptq':
       if not openllm.utils.is_autogptq_available():
         raise openllm.exceptions.OpenLLMException(
-            "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
-        )
+            "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
       if llm.config['model_type'] != 'causal_lm':
-        raise openllm.exceptions.OpenLLMException(
-            f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
+        raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
       logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
       model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id,
                                                           *decls,
-                                                          quantize_config=t.cast('autogptq.BaseQuantizeConfig',
-                                                                                 llm.quantization_config),
+                                                          quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
                                                           trust_remote_code=trust_remote_code,
                                                           use_safetensors=safe_serialisation,
                                                           **hub_attrs,
                                                           **attrs)
-      update_model(bentomodel,
-                   metadata={
-                       '_pretrained_class': model.__class__.__name__,
-                       '_framework': model.model.framework
-                   })
+      update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
       model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
     else:
       architectures = getattr(config, 'architectures', [])
@@ -159,18 +143,14 @@ def import_model(llm: openllm.LLM[M, T],
           model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
         else:
           # we will clone the all tings into the bentomodel path without loading model into memory
-          snapshot_download(llm.model_id,
-                            local_dir=bentomodel.path,
-                            local_dir_use_symlinks=False,
-                            ignore_patterns=HfIgnore.ignore_patterns(llm))
+          snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
   except Exception:
     raise
   else:
     bentomodel.flush() # type: ignore[no-untyped-call]
     bentomodel.save(_model_store)
     openllm.utils.analytics.track(
-        openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module,
-                                               model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
+        openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
   finally:
     bentomodel.exit_cloudpickle_context(imported_modules)
     # NOTE: We need to free up the cache after importing the model
@@ -189,36 +169,29 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
   try:
     model = bentoml.models.get(llm.tag)
     if Version(model.info.api_version) < Version('v2'):
-      raise openllm.exceptions.OpenLLMException(
-          'Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
+      raise openllm.exceptions.OpenLLMException('Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
     if model.info.labels['backend'] != llm.__llm_backend__:
       raise openllm.exceptions.OpenLLMException(
-          f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}."
-      )
+          f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}.")
     return model
   except Exception as err:
     if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
-    raise openllm.exceptions.OpenLLMException(
-        f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
+    raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err

 def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
   config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
-  safe_serialization = openllm.utils.first_not_none(t.cast(
-      t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
+  safe_serialization = openllm.utils.first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
                                                     attrs.pop('safe_serialization', None),
                                                     default=llm._serialisation_format == 'safetensors')
   if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
     if not openllm.utils.is_autogptq_available():
       raise openllm.exceptions.OpenLLMException(
-          "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
-      )
+          "GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
     if llm.config['model_type'] != 'causal_lm':
-      raise openllm.exceptions.OpenLLMException(
-          f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
+      raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
     return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path,
                                                        *decls,
-                                                       quantize_config=t.cast('autogptq.BaseQuantizeConfig',
-                                                                              llm.quantization_config),
+                                                       quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
                                                        trust_remote_code=llm.trust_remote_code,
                                                        use_safetensors=safe_serialization,
                                                        **hub_attrs,
@@ -24,13 +24,11 @@ if t.TYPE_CHECKING:
   from openllm_core._typing_compat import T
 else:
   transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(),
-                                                      'transformers'), openllm_core.utils.LazyLoader(
-                                                          'torch', globals(), 'torch')
+                                                      'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')

 _object_setattr = object.__setattr__

-def process_config(model_id: str, trust_remote_code: bool,
-                   **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
+def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
   '''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.

   Args:
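The docstring in the hunk above describes process_config as splitting **attrs into hub-download options and model kwargs around the resolved PretrainedConfig. A hedged usage sketch, assuming the HUB_ATTRS list shown earlier drives the split (the model id and kwargs here are illustrative only):

# Illustrative only: 'facebook/opt-125m' and the kwargs are example values;
# process_config is the helper defined in the hunk above.
config, hub_attrs, attrs = process_config('facebook/opt-125m', trust_remote_code=False,
                                          revision='main', torch_dtype='float16')
# 'revision' appears in HUB_ATTRS, so it ends up in hub_attrs for the hub download;
# 'torch_dtype' is not a hub option, so it stays in attrs and is passed to the model class.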
@@ -55,8 +53,7 @@ def process_config(model_id: str, trust_remote_code: bool,
   return config, hub_attrs, attrs

 def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
-  __cls = getattr(transformers,
-                  openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
+  __cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
   if __cls is None:
     raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
   return __cls
@@ -105,13 +102,11 @@ def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
   infer_fn: tuple[str, ...] = ('__call__',)
   default_config = ModelSignature(batchable=False)
   if llm.__llm_backend__ in {'pt', 'vllm'}:
-    infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample',
-                 'group_beam_search', 'constrained_beam_search',
-                )
+    infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search',
+                 'constrained_beam_search',
+                )
   elif llm.__llm_backend__ == 'tf':
-    infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search',
-                 'contrastive_search',
-                )
+    infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
   else:
     infer_fn += ('generate',)
   return {k: default_config for k in infer_fn}
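For reference, the mapping built in the hunk above gives every inference entry point the same non-batchable signature. A small self-contained sketch of the resulting shape for the 'pt'/'vllm' branch, using a stand-in for BentoML's ModelSignature since the real import path is not shown here:

import typing as t

class ModelSignature(t.NamedTuple):  # stand-in for the BentoML ModelSignature used above
  batchable: bool = False

default_config = ModelSignature(batchable=False)
infer_fn = ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample',
            'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
signatures = {k: default_config for k in infer_fn}
print(signatures['generate'])  # ModelSignature(batchable=False)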