fix(torch_dtype): load eagerly (#631)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-13 13:48:04 -05:00
committed by GitHub
parent 0924c0b34d
commit d358e68539

View File

@@ -249,11 +249,13 @@ class LLM(t.Generic[M, T], ReprMixin):
import transformers
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
config_dtype = getattr(
transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code),
'torch_dtype',
None,
)
try:
hf_config = transformers.AutoConfig.from_pretrained(
self.bentomodel.path, trust_remote_code=self.trust_remote_code
)
except OpenLLMException:
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
config_dtype = getattr(hf_config, 'torch_dtype', None)
if config_dtype is None:
config_dtype = torch.float32
if self.__llm_torch_dtype__ == 'auto':