diff --git a/src/openllm/models/mpt/modeling_mpt.py b/src/openllm/models/mpt/modeling_mpt.py index d56b012d..975b1a32 100644 --- a/src/openllm/models/mpt/modeling_mpt.py +++ b/src/openllm/models/mpt/modeling_mpt.py @@ -111,7 +111,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel: torch_dtype = attrs.pop("torch_dtype", torch.float32) device_map = attrs.pop("device_map", None) - trust_remote_code = attrs.pop("trust_remote_code", False) + trust_remote_code = attrs.pop("trust_remote_code", True) _ref = bentoml.transformers.get(tag) config = get_mpt_config(