diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index f9348314..07962c7e 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -406,8 +406,8 @@ class LLM(t.Generic[M, T]): @functools.lru_cache(maxsize=1) def _torch_dtype_mapping() -> dict[str, torch.dtype]: import torch; return { - 'half': torch.float16, 'float16': torch.float16, - 'float': torch.float32, 'float32': torch.float32, + 'half': torch.float16, 'float16': torch.float16, # + 'float': torch.float32, 'float32': torch.float32, # 'bfloat16': torch.bfloat16, } def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--'))