Mirror of https://github.com/bentoml/OpenLLM.git
fix(chatglm): generation tokens not concatenated correctly
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
@@ -41,7 +41,7 @@ class ChatGLMConfig(
     Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
     """
 
-    retain_history: bool = True
+    retain_history: bool = False
     """Whether to retain history given to the model. If set to True, then the model will retain given history."""
 
     use_half_precision: bool = True
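A note on this first hunk: flipping retain_history to False makes stateless, single-turn generation the default. A minimal sketch of the intended behavior, assuming a caller-managed chat_history list (the names mirror the @@ -121,7 hunk further down; the surrounding plumbing is illustrative, not OpenLLM's implementation):

# Sketch only: how retain_history gates history accumulation.
retain_history = False  # new default: each prompt is independent

chat_history: list[tuple[str, str]] = []

def record(prompt: str, generation_result: str) -> str:
    if retain_history:
        # With the old default (True), callers had to supply a history
        # list, or the assert in the real code below would fire.
        chat_history.append((prompt, generation_result))
    return generation_result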
@@ -13,7 +13,6 @@
 # limitations under the License.
 from __future__ import annotations
 
-import re
 import typing as t
 
 import bentoml
@@ -39,32 +38,12 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores
 
 
-def process_response(
-    response: str,
-    use_default_prompt_template: bool = True,
-):
-    response = response.strip()
-    if use_default_prompt_template:
-        response = response.replace("[[训练时间]]", "2023年")
-        punkts = [
-            [",", ","],
-            ["!", "!"],
-            [":", ":"],
-            [";", ";"],
-            ["\?", "?"],
-        ]
-        for item in punkts:
-            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
-            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
-    return response
-
-
 class ChatGLM(openllm.LLM):
     __openllm_internal__ = True
 
-    default_model = "THUDM/chatglm-6b-int4"
+    default_model = "thudm/chatglm-6b-int4"
 
-    pretrained = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
+    pretrained = ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"]
 
     device = torch.device("cuda")
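The removed process_response helper did two things: it swapped ChatGLM's "[[训练时间]]" placeholder (Chinese for "training time") for "2023年" ("the year 2023"), and it converted ASCII punctuation adjacent to a CJK character (Unicode range U+4E00-U+9FFF) into the full-width equivalents conventional in Chinese text. A standalone sketch of the punctuation normalization, runnable as-is:

import re

# Half-width -> full-width punctuation pairs, as in the removed helper.
punkts = [[",", ","], ["!", "!"], [":", ":"], [";", ";"], [r"\?", "?"]]

def normalize(response: str) -> str:
    for halfwidth, fullwidth in punkts:
        # CJK character followed by ASCII punctuation.
        response = re.sub(r"([\u4e00-\u9fff])%s" % halfwidth, r"\1%s" % fullwidth, response)
        # ASCII punctuation followed by a CJK character.
        response = re.sub(r"%s([\u4e00-\u9fff])" % halfwidth, r"%s\1" % fullwidth, response)
    return response

print(normalize("你好,世界!"))  # -> 你好,世界!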
@@ -121,7 +100,7 @@ class ChatGLM(openllm.LLM):
         if self.config.retain_history:
             assert chat_history is not None, "'retain_history' is True while there is no history provided."
             chat_history.append((prompt, generation_result))
-        return generation_result
+        return "".join(generation_result)
 
     @torch.inference_mode()
     def generate(self, prompt: str, use_default_prompt_template: bool = True, **attrs: t.Any) -> str:
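This hunk is the fix the commit is named for: the streaming generation path accumulates decoded chunks in a list, so returning generation_result handed callers a Python list rather than text. A tiny illustration (the chunk values are made up):

# Illustration of the bug: chunks accumulate as a list of strings.
generation_result = ["The sky ", "is blue ", "because of Rayleigh scattering."]

# Old behavior returned the list object itself; the fix returns one string.
assert "".join(generation_result) == "The sky is blue because of Rayleigh scattering."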
@@ -144,4 +123,4 @@ class ChatGLM(openllm.LLM):
         )
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = self.tokenizer.decode(outputs)
-        return process_response(response, use_default_prompt_template)
+        return self.model.process_response(response)
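The last hunk drops the local helper in favor of the process_response method shipped in ChatGLM's own remote modeling code, keeping postprocessing in sync with upstream. A sketch of where that method comes from, assuming the usual trust_remote_code loading path (heavy to actually run, since it downloads the weights):

from transformers import AutoModel

# ChatGLM's modeling_chatglm.py (loaded via trust_remote_code) defines
# process_response(), which is what self.model.process_response invokes.
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
print(model.process_response("你好,世界"))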