fix(chatglm): support MacOS deployment

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date: 2023-05-24 05:05:48 -07:00
parent 9139360426
commit 135bafacaf

@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
+import platform
 import typing as t
 import bentoml
@@ -39,7 +40,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
 class ChatGLM(openllm.LLM, _internal=True):
-    default_model = "THUDM/chatglm-6b"
+    default_model = "THUDM/chatglm-6b-int4"
     variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
@@ -69,10 +70,16 @@ class ChatGLM(openllm.LLM, _internal=True):
         temperature: float | None = None,
         **kwargs: t.Any,
     ) -> t.Any:
-        if torch.cuda.is_available():
-            self.model = self.model.cuda()
         if self.config.use_half_precision:
             self.model = self.model.half()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+        else:
+            self.model = self.model.float()
+        if platform.system() == "Darwin":
+            self.model = self.model.to("mps")
         self.model.eval()
         logit_processor = LogitsProcessorList()
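
The device-placement logic introduced by this patch can be summarized as the standalone sketch below (place_model, its parameters, and the torch.backends.mps.is_available() mention are illustrative only, not names from the OpenLLM codebase):

    import platform

    import torch


    def place_model(model: torch.nn.Module, use_half_precision: bool) -> torch.nn.Module:
        # Optionally cast to fp16 before choosing a device.
        if use_half_precision:
            model = model.half()
        if torch.cuda.is_available():
            # NVIDIA GPU present: keep the previous CUDA behaviour.
            model = model.cuda()
        else:
            # No CUDA: fall back to fp32 so CPU inference still works.
            model = model.float()
        if platform.system() == "Darwin":
            # On macOS, move the model to Apple's Metal (MPS) backend.
            # A stricter variant could guard this with torch.backends.mps.is_available().
            model = model.to("mps")
        return model.eval()

In short, the CUDA-only path becomes a CUDA -> CPU/float32 fallback with an extra hop to the "mps" device on macOS, and the default model switches to the int4 checkpoint, presumably to keep the memory footprint manageable on Apple Silicon machines.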