mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-06 13:52:21 -05:00
fix: loading correct local models (#599)
* fix(model): loading local correctly

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update repr and correct bentomodel processor

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: cleanup transformers implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
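The substance of the fix, before the hunks below: `LLM.bentomodel` used to call back into `openllm.serialisation.get(self)` on every access, which could resolve a different model-store entry than the one just imported, notably for local models. The class now caches the `bentoml.Model` returned at import time and serves it from the property. A minimal runnable sketch of the before/after, with hypothetical names and a plain dict standing in for `bentoml.Model`:

class SketchLLM:
  def __init__(self, model_id):
    self._model_id = model_id
    model = self._import_model()  # stand-in for openllm.serialisation.import_model(...)
    self._tag = model['tag']      # resolve the tag, as in the diff
    self._bentomodel = model      # new: cache the imported model on the instance

  def _import_model(self):
    return {'tag': self._model_id + ':latest'}

  @property
  def bentomodel(self):
    # old: return openllm.serialisation.get(self)  (re-resolved on every access)
    return self._bentomodel       # new: exactly the model imported above

print(SketchLLM('facebook/opt-125m').bentomodel)  # {'tag': 'facebook/opt-125m:latest'}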
@@ -121,7 +121,7 @@ _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple'
 @attr.define(slots=True, repr=False, init=False)
-class LLM(t.Generic[M, T]):
+class LLM(t.Generic[M, T], ReprMixin):
   _model_id: str
   _revision: str | None
   _quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
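The hunk above mixes `ReprMixin` into `LLM`; it is what makes the `__repr_keys__` and `__repr_args__` hooks added further down take effect. A minimal sketch of the contract such a mixin implies — an illustration, not OpenLLM's actual `ReprMixin`:

class ReprMixin:
  @property
  def __repr_keys__(self):
    raise NotImplementedError  # subclasses list the keys shown in the repr

  def __repr_args__(self):
    raise NotImplementedError  # subclasses yield a (name, value) pair per key

  def __repr__(self):
    return f"{self.__class__.__name__}({', '.join(f'{k}={v!r}' for k, v in self.__repr_args__())})"

class Point(ReprMixin):
  def __init__(self, x, y):
    self.x, self.y = x, y
  @property
  def __repr_keys__(self):
    return {'x', 'y'}
  def __repr_args__(self):
    yield 'x', self.x
    yield 'y', self.y

print(Point(1, 2))  # Point(x=1, y=2)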
@@ -136,6 +136,7 @@ class LLM(t.Generic[M, T]):
   _prompt_template: PromptTemplate | None
   _system_message: str | None
 
+  _bentomodel: bentoml.Model = attr.field(init=False)
   __llm_config__: LLMConfig | None = None
   __llm_backend__: LiteralBackend = None  # type: ignore
   __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
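`_bentomodel` is declared with `attr.field(init=False)`, so attrs leaves it out of any generated constructor signature and the import path assigns it afterwards. A small standalone example of that attrs pattern (hypothetical class, illustration only):

import attr

@attr.define
class Sketch:
  name: str
  # Not a constructor argument; populated after construction, like _bentomodel.
  resolved: str = attr.field(init=False, default='')

s = Sketch(name='demo')   # there is no `resolved=` parameter
s.resolved = 'demo:latest'
print(s)                  # Sketch(name='demo', resolved='demo:latest')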
@@ -229,6 +230,7 @@ class LLM(t.Generic[M, T]):
     model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
     # resolve the tag
     self._tag = model.tag
+    self._bentomodel = model
 
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(
@@ -256,7 +258,14 @@ class LLM(t.Generic[M, T]):
     if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
     super().__setattr__(attr,value)
   @property
-  def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
+  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
+  def __repr_args__(self) -> ReprArgs:
+    yield 'model_id', self._model_id if not self._local else self.tag.name
+    yield 'revision', self._revision if self._revision else self.tag.version
+    yield 'backend', self.__llm_backend__
+    yield 'type', self.llm_type
+  @property
+  def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
   def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
   @property
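With those hooks in place, a `ReprMixin`-derived repr renders the four declared keys; note that `model_id` falls back to `self.tag.name` when the model is local, which is the user-visible half of this fix. A standalone illustration with made-up values:

def repr_args():  # mirrors the shape of LLM.__repr_args__; values are invented
  yield 'model_id', 'facebook/opt-125m'
  yield 'revision', 'main'
  yield 'backend', 'pt'
  yield 'type', 'opt'

print('LLM(' + ', '.join(f'{k}={v!r}' for k, v in repr_args()) + ')')
# LLM(model_id='facebook/opt-125m', revision='main', backend='pt', type='opt')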
@@ -268,7 +277,7 @@ class LLM(t.Generic[M, T]):
   @property
   def tag(self)->bentoml.Tag:return self._tag
   @property
-  def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
+  def bentomodel(self)->bentoml.Model:return self._bentomodel
   @property
   def config(self)->LLMConfig:
     if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
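Returning the cached `_bentomodel` pins the property to exactly the model imported for this instance, rather than whatever `openllm.serialisation.get` resolves at access time. A toy sketch of the difference, with a plain dict standing in for the model store:

store = {'model': 'v1'}

class LazyLookup:
  @property
  def model(self):
    return store['model']         # re-resolved on every access

class CachedAtImport:
  def __init__(self):
    self._model = store['model']  # captured once at import time

  @property
  def model(self):
    return self._model

lazy, cached = LazyLookup(), CachedAtImport()
store['model'] = 'v2'             # the store changes underneath
print(lazy.model, cached.model)   # v2 v1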
@@ -282,6 +291,8 @@ class LLM(t.Generic[M, T]):
     return self.__llm_quantization_config__
   @property
   def has_adapters(self)->bool:return self._adapter_map is not None
+  @property
+  def local(self)->bool:return self._local
   # NOTE: The section below defines a loose contract with langchain's LLM interface.
   @property
   def llm_type(self)->str:return normalise_model_name(self._model_id)