fix(chatglm): generation tokens not concatenated correctly

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: aarnphm-ec2-dev
Date: 2023-06-10 09:46:33 +00:00
Commit: ebfed3c116 (parent: d70530cb0e)

2 changed files with 5 additions and 26 deletions

@@ -41,7 +41,7 @@ class ChatGLMConfig(
     Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
     """

-    retain_history: bool = True
+    retain_history: bool = False
     """Whether to retain history given to the model. If set to True, then the model will retain given history."""

     use_half_precision: bool = True
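
The flipped default matters downstream: the chat path in the second file only appends turns when this flag is on. A minimal stand-alone sketch of that behavior (retain_history, chat_history, prompt, and generation_result are names taken from the diff; the surrounding scaffolding is hypothetical):

# Stand-alone sketch of what retain_history toggles. Names mirror the
# diff's history-handling hunk; the values are hypothetical.
retain_history = False  # the new default
chat_history: list[tuple[str, str]] = []

prompt = "你好"
generation_result = "你好！有什么可以帮您的吗？"

if retain_history:
    # Old default (True): every (prompt, response) turn was kept implicitly.
    chat_history.append((prompt, generation_result))

assert chat_history == []  # new default: callers must opt in to history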

@@ -13,7 +13,6 @@
 # limitations under the License.
 from __future__ import annotations

-import re
 import typing as t

 import bentoml
@@ -39,32 +38,12 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores

-def process_response(
-    response: str,
-    use_default_prompt_template: bool = True,
-):
-    response = response.strip()
-    if use_default_prompt_template:
-        response = response.replace("[[训练时间]]", "2023年")
-    punkts = [
-        [",", "，"],
-        ["!", "！"],
-        [":", "："],
-        [";", "；"],
-        ["\?", "？"],
-    ]
-    for item in punkts:
-        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
-        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
-    return response

 class ChatGLM(openllm.LLM):
     __openllm_internal__ = True

-    default_model = "THUDM/chatglm-6b-int4"
+    default_model = "thudm/chatglm-6b-int4"

-    pretrained = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
+    pretrained = ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4"]

     device = torch.device("cuda")
@@ -121,7 +100,7 @@ class ChatGLM(openllm.LLM):
         if self.config.retain_history:
            assert chat_history is not None, "'retain_history' is True while there is no history provided."
             chat_history.append((prompt, generation_result))
-        return generation_result
+        return "".join(generation_result)

     @torch.inference_mode()
     def generate(self, prompt: str, use_default_prompt_template: bool = True, **attrs: t.Any) -> str:
@@ -144,4 +123,4 @@ class ChatGLM(openllm.LLM):
         )
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = self.tokenizer.decode(outputs)
-        return process_response(response, use_default_prompt_template)
+        return self.model.process_response(response)
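
For context on the headline fix: the chat loop accumulates decoded pieces into generation_result, so returning the accumulator as-is handed callers a list rather than the generated text. A minimal sketch with hypothetical values:

# Sketch of the concatenation fix; the pieces are hypothetical stand-ins
# for the decoded chunks accumulated during chat.
generation_result = ["你好", "，", "世界"]

broken = generation_result          # a list[str], not the response text
fixed = "".join(generation_result)  # "你好，世界"
assert fixed == "你好，世界"

The last hunk drops the local process_response copy in favor of calling it on the model itself; the ChatGLM checkpoints' remote code (THUDM's modeling_chatglm.py) ships an equivalent process_response method, which is also why the use_default_prompt_template argument disappears from the call.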