From 135bafacaf2f7c0f40efcda76f58ad55e9692c30 Mon Sep 17 00:00:00 2001
From: Aaron <29749331+aarnphm@users.noreply.github.com>
Date: Wed, 24 May 2023 05:05:48 -0700
Subject: [PATCH] fix(chatglm): support MacOS deployment

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 src/openllm/models/chatglm/modeling_chatglm.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py
index db58a551..19ba4022 100644
--- a/src/openllm/models/chatglm/modeling_chatglm.py
+++ b/src/openllm/models/chatglm/modeling_chatglm.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import platform
 import typing as t
 
 import bentoml
@@ -39,7 +40,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
 
 
 class ChatGLM(openllm.LLM, _internal=True):
-    default_model = "THUDM/chatglm-6b"
+    default_model = "THUDM/chatglm-6b-int4"
 
     variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
@@ -69,10 +70,16 @@ class ChatGLM(openllm.LLM, _internal=True):
         temperature: float | None = None,
         **kwargs: t.Any,
     ) -> t.Any:
-        if torch.cuda.is_available():
-            self.model = self.model.cuda()
         if self.config.use_half_precision:
             self.model = self.model.half()
+            if torch.cuda.is_available():
+                self.model = self.model.cuda()
+            else:
+                self.model = self.model.float()
+
+        if platform.system() == "Darwin":
+            self.model = self.model.to("mps")
+
         self.model.eval()
 
         logit_processor = LogitsProcessorList()