diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index cf416012..3175d692 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -38,7 +38,7 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken @property def import_kwargs(self): model_kwds = { - "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, + "device_map": "auto" if torch.cuda.is_available() else None, "torch_dtype": torch.bfloat16, } tokenizer_kwds = {"padding_side": "left"}