fix: correctly load local models (#599)

* fix(model): load local models correctly

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update repr and correct bentomodel processor

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: cleanup transformers implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by Aaron Pham on 2023-11-10 02:36:12 -05:00; committed by GitHub
parent 5e45245457
commit fa2038f4e2
11 changed files with 121 additions and 111 deletions

@@ -121,7 +121,7 @@ _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple'
 @attr.define(slots=True, repr=False, init=False)
-class LLM(t.Generic[M, T]):
+class LLM(t.Generic[M, T], ReprMixin):
   _model_id: str
   _revision: str | None
   _quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
@@ -136,6 +136,7 @@ class LLM(t.Generic[M, T]):
   _prompt_template: PromptTemplate | None
   _system_message: str | None
+  _bentomodel: bentoml.Model = attr.field(init=False)
   __llm_config__: LLMConfig | None = None
   __llm_backend__: LiteralBackend = None  # type: ignore
   __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -229,6 +230,7 @@ class LLM(t.Generic[M, T]):
     model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
     # resolve the tag
     self._tag = model.tag
+    self._bentomodel = model
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(
@@ -256,7 +258,14 @@ class LLM(t.Generic[M, T]):
     if attr in _reserved_namespace: raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
     super().__setattr__(attr, value)
   @property
-  def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
+  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
+  def __repr_args__(self) -> ReprArgs:
+    yield 'model_id', self._model_id if not self._local else self.tag.name
+    yield 'revision', self._revision if self._revision else self.tag.version
+    yield 'backend', self.__llm_backend__
+    yield 'type', self.llm_type
+  @property
+  def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
   def trust_remote_code(self) -> bool: return first_not_none(check_bool_env('TRUST_REMOTE_CODE', False), default=self.__llm_trust_remote_code__)
   @property
@@ -268,7 +277,7 @@ class LLM(t.Generic[M, T]):
   @property
   def tag(self) -> bentoml.Tag: return self._tag
   @property
-  def bentomodel(self) -> bentoml.Model: return openllm.serialisation.get(self)
+  def bentomodel(self) -> bentoml.Model: return self._bentomodel
   @property
   def config(self) -> LLMConfig:
     if self.__llm_config__ is None: self.__llm_config__ = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
@@ -282,6 +291,8 @@ class LLM(t.Generic[M, T]):
     return self.__llm_quantization_config__
   @property
   def has_adapters(self) -> bool: return self._adapter_map is not None
+  @property
+  def local(self) -> bool: return self._local
   # NOTE: The section below defines a loose contract with langchain's LLM interface.
   @property
   def llm_type(self) -> str: return normalise_model_name(self._model_id)
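For context, the effect of this change can be shown with a small usage sketch (not part of the diff; the constructor call and model id are assumptions based on openllm 0.4.x usage). Before the fix, each `llm.bentomodel` access re-resolved the model through `openllm.serialisation.get`, which could return the wrong store entry for locally imported models; now the `bentoml.Model` produced at import time is cached on `_bentomodel` and returned directly:

    import openllm

    # Hypothetical model id, for illustration only.
    llm = openllm.LLM('facebook/opt-125m')

    # The cached _bentomodel and the resolved tag stay in sync,
    # including for models loaded from a local path (llm.local is True).
    assert llm.bentomodel.tag == llm.tag

    # ReprMixin now drives repr(): model_id, revision, backend, type.
    print(repr(llm))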