fix(yapf): align weird newline breaks [generated] [skip ci] (#284)

fix(yapf): align weird newline breaks

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2023-09-01 05:34:22 -04:00 (committed via GitHub)
Parent: 3e45530abd
Commit: b7af7765d4
91 changed files with 811 additions and 1678 deletions
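For context, the collapsed lines in the hunks below are what yapf produces once the column limit is wide enough for a wrapped call to fit on one line. A minimal way to reproduce the effect locally, assuming yapf's FormatCode API; the style values are illustrative, not taken from the repository's actual yapf configuration:

from yapf.yapflib.yapf_api import FormatCode

src = ("x = first_not_none(attrs.get('safe_serialization'),\n"
       "                   default=serialisation == 'safetensors')\n")
# with a large enough column limit, yapf joins the wrapped call into a single line
formatted, changed = FormatCode(src, style_config='{based_on_style: pep8, column_limit: 150}')
print(formatted)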


@@ -59,14 +59,12 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
return tokenizer
class _Caller(t.Protocol[P]):
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
...
_extras = ['get', 'import_model', 'load_model']
def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.


@@ -7,6 +7,5 @@ FRAMEWORK_TO_AUTOCLASS_MAPPING = {
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
}
HUB_ATTRS = [
'cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision',
'subfolder', 'use_auth_token'
'cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token'
]
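HUB_ATTRS is the allow-list used to peel Hugging Face Hub download options off a caller's **attrs before the remainder is forwarded to the model class. A hedged sketch of that split; the real logic lives in process_config, shown further down:

def split_hub_attrs(**attrs):
  # keyword arguments recognised by huggingface_hub / from_pretrained downloads
  hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
  return hub_attrs, attrs  # hub-specific kwargs, remaining model kwargs

hub_attrs, rest = split_hub_attrs(revision='main', torch_dtype='auto')
# hub_attrs == {'revision': 'main'}; rest == {'torch_dtype': 'auto'}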


@@ -13,11 +13,7 @@ if t.TYPE_CHECKING:
_conversion_strategy = {'pt': 'ggml'}
def import_model(llm: openllm.LLM[t.Any, t.Any],
*decls: t.Any,
trust_remote_code: bool = True,
**attrs: t.Any,
) -> bentoml.Model:
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
raise NotImplementedError('Currently work in progress.')
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:


@@ -68,24 +68,18 @@ def import_model(llm: openllm.LLM[M, T],
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'),
default=llm._serialisation_format == 'safetensors')
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
# Disable safe serialization with vLLM
if llm.__llm_backend__ == 'vllm': safe_serialisation = False
metadata: DictStrAny = {
'safe_serialisation': safe_serialisation,
'_quantize': quantize_method is not None and quantize_method
}
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
signatures: DictStrAny = {}
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
)
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
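first_not_none is used throughout this hunk to resolve a setting from a chain of candidates. A minimal sketch of the semantics assumed here, not the actual openllm.utils implementation:

import typing as t

def first_not_none(*args: t.Any, default: t.Any = None) -> t.Any:
  # return the first argument that is not None, otherwise the default
  return next((arg for arg in args if arg is not None), default)

# e.g. an explicit safe_serialization kwarg wins, otherwise derive it from the format
print(first_not_none(None, True))      # True (second candidate)
print(first_not_none(default=False))   # False (no candidates set)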
@@ -95,10 +89,7 @@ def import_model(llm: openllm.LLM[M, T],
if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = 'pt' if llm.__llm_backend__ == 'vllm' else llm.__llm_backend__
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id,
trust_remote_code=trust_remote_code,
**hub_attrs,
**tokenizer_attrs)
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
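The pad_token fallback above is the standard Hugging Face pattern for tokenizers that ship without a padding token; reusing the EOS token lets padded batches be built. The same pattern in isolation (the model id is only illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2', trust_remote_code=False)
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token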
@@ -117,25 +108,18 @@ def import_model(llm: openllm.LLM[M, T],
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
)
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id,
*decls,
quantize_config=t.cast('autogptq.BaseQuantizeConfig',
llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs)
update_model(bentomodel,
metadata={
'_pretrained_class': model.__class__.__name__,
'_framework': model.model.framework
})
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
else:
architectures = getattr(config, 'architectures', [])
@@ -159,18 +143,14 @@ def import_model(llm: openllm.LLM[M, T],
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
# we will clone all the things into the bentomodel path without loading the model into memory
snapshot_download(llm.model_id,
local_dir=bentomodel.path,
local_dir_use_symlinks=False,
ignore_patterns=HfIgnore.ignore_patterns(llm))
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
except Exception:
raise
else:
bentomodel.flush() # type: ignore[no-untyped-call]
bentomodel.save(_model_store)
openllm.utils.analytics.track(
openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module,
model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
# NOTE: We need to free up the cache after importing the model
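The snapshot_download call above copies the repository files straight into the bentomodel path without instantiating the model. The same huggingface_hub call in isolation; the repo id and ignore patterns are illustrative, HfIgnore.ignore_patterns computes the real ones per backend:

from huggingface_hub import snapshot_download

snapshot_download('facebook/opt-125m',
                  local_dir='/tmp/bentomodel-path',
                  local_dir_use_symlinks=False,
                  ignore_patterns=['*.h5', '*.msgpack', '*.onnx'])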
@@ -189,36 +169,29 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
try:
model = bentoml.models.get(llm.tag)
if Version(model.info.api_version) < Version('v2'):
raise openllm.exceptions.OpenLLMException(
'Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
raise openllm.exceptions.OpenLLMException('Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
if model.info.labels['backend'] != llm.__llm_backend__:
raise openllm.exceptions.OpenLLMException(
f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}."
)
f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}.")
return model
except Exception as err:
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
raise openllm.exceptions.OpenLLMException(
f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
safe_serialization = openllm.utils.first_not_none(t.cast(
t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
safe_serialization = openllm.utils.first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
attrs.pop('safe_serialization', None),
default=llm._serialisation_format == 'safetensors')
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
)
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path,
*decls,
quantize_config=t.cast('autogptq.BaseQuantizeConfig',
llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=llm.trust_remote_code,
use_safetensors=safe_serialization,
**hub_attrs,
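The api_version gate in get() above relies on PEP 440 version ordering: models saved with an older store layout are rejected until they are pruned and re-imported. A small sketch of that check:

from packaging.version import Version

def needs_reimport(api_version: str) -> bool:
  # 'v1'-era models predate the current store layout and must be saved again
  return Version(api_version) < Version('v2')

assert needs_reimport('v1') and not needs_reimport('v2')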


@@ -24,13 +24,11 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import T
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(),
'transformers'), openllm_core.utils.LazyLoader(
'torch', globals(), 'torch')
'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
_object_setattr = object.__setattr__
def process_config(model_id: str, trust_remote_code: bool,
**attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
@@ -55,8 +53,7 @@ def process_config(model_id: str, trust_remote_code: bool,
return config, hub_attrs, attrs
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
__cls = getattr(transformers,
openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
if __cls is None:
raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
return __cls
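infer_tokenizers_from_llm above is a thin getattr lookup against the transformers namespace, defaulting to AutoTokenizer when the config does not pin a tokenizer class. A standalone sketch ('GPT2Tokenizer' is only an example value):

import typing as t
import transformers

def infer_tokenizer_class(tokenizer_class: t.Optional[str]) -> type:
  cls = getattr(transformers, tokenizer_class or 'AutoTokenizer', None)
  if cls is None:
    raise ValueError(f'Cannot infer a tokenizer class from {tokenizer_class!r}')
  return cls

infer_tokenizer_class(None)              # -> transformers.AutoTokenizer
infer_tokenizer_class('GPT2Tokenizer')   # -> transformers.GPT2Tokenizer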
@@ -105,13 +102,11 @@ def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ('__call__',)
default_config = ModelSignature(batchable=False)
if llm.__llm_backend__ in {'pt', 'vllm'}:
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample',
'group_beam_search', 'constrained_beam_search',
)
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search',
'constrained_beam_search',
)
elif llm.__llm_backend__ == 'tf':
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search',
'contrastive_search',
)
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
else:
infer_fn += ('generate',)
return {k: default_config for k in infer_fn}
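For reference, the signature table built above reduces to a per-backend tuple of inference method names, each mapped to a non-batchable default. A condensed sketch with a stand-in ModelSignature (the real one comes from bentoml):

from dataclasses import dataclass

@dataclass
class ModelSignature:  # stand-in carrying only the field used here
  batchable: bool = False

def make_model_signatures(backend: str) -> dict[str, ModelSignature]:
  infer_fn: tuple[str, ...] = ('__call__',)
  if backend in {'pt', 'vllm'}:
    infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample',
                 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
  elif backend == 'tf':
    infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search',
                 'sample', 'beam_search', 'contrastive_search')
  else:
    infer_fn += ('generate',)
  default_config = ModelSignature(batchable=False)
  return {k: default_config for k in infer_fn}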