diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py index db58a551..19ba4022 100644 --- a/src/openllm/models/chatglm/modeling_chatglm.py +++ b/src/openllm/models/chatglm/modeling_chatglm.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import platform import typing as t import bentoml @@ -39,7 +40,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): class ChatGLM(openllm.LLM, _internal=True): - default_model = "THUDM/chatglm-6b" + default_model = "THUDM/chatglm-6b-int4" variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"] @@ -69,10 +70,16 @@ class ChatGLM(openllm.LLM, _internal=True): temperature: float | None = None, **kwargs: t.Any, ) -> t.Any: - if torch.cuda.is_available(): - self.model = self.model.cuda() if self.config.use_half_precision: self.model = self.model.half() + if torch.cuda.is_available(): + self.model = self.model.cuda() + else: + self.model = self.model.float() + + if platform.system() == "Darwin": + self.model = self.model.to("mps") + self.model.eval() logit_processor = LogitsProcessorList()