mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-05 14:22:43 -04:00
fix(chatglm): support MacOS deployment
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
@@ -39,7 +40,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
|
||||
class ChatGLM(openllm.LLM, _internal=True):
|
||||
default_model = "THUDM/chatglm-6b"
|
||||
default_model = "THUDM/chatglm-6b-int4"
|
||||
|
||||
variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
|
||||
|
||||
@@ -69,10 +70,16 @@ class ChatGLM(openllm.LLM, _internal=True):
|
||||
temperature: float | None = None,
|
||||
**kwargs: t.Any,
|
||||
) -> t.Any:
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
if self.config.use_half_precision:
|
||||
self.model = self.model.half()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
else:
|
||||
self.model = self.model.float()
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
self.model = self.model.to("mps")
|
||||
|
||||
self.model.eval()
|
||||
|
||||
logit_processor = LogitsProcessorList()
|
||||
|
||||
Reference in New Issue
Block a user