mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-19 15:18:12 -05:00
fix(torch_dtype): load eagerly (#631)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -249,11 +249,13 @@ class LLM(t.Generic[M, T], ReprMixin):
     import transformers

     if not isinstance(self.__llm_torch_dtype__, torch.dtype):
-      config_dtype = getattr(
-        transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code),
-        'torch_dtype',
-        None,
-      )
+      try:
+        hf_config = transformers.AutoConfig.from_pretrained(
+          self.bentomodel.path, trust_remote_code=self.trust_remote_code
+        )
+      except OpenLLMException:
+        hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
+      config_dtype = getattr(hf_config, 'torch_dtype', None)
       if config_dtype is None:
         config_dtype = torch.float32
       if self.__llm_torch_dtype__ == 'auto':
Reference in New Issue
Block a user