From 392c7a8139c754e08a0808a477a8bfb71a032918 Mon Sep 17 00:00:00 2001
From: XunchaoZ <69278985+XunchaoZ@users.noreply.github.com>
Date: Mon, 30 Oct 2023 17:28:42 -0400
Subject: [PATCH] Fix chat template and message list bug (#549)
---
openllm-python/src/openllm/_conversation.py | 30 +++++++++++++++----
openllm-python/src/openllm/protocol/openai.py | 1 +
2 files changed, 25 insertions(+), 6 deletions(-)
diff --git a/openllm-python/src/openllm/_conversation.py b/openllm-python/src/openllm/_conversation.py
index 77eb71c7..2577c7af 100644
--- a/openllm-python/src/openllm/_conversation.py
+++ b/openllm-python/src/openllm/_conversation.py
@@ -6,7 +6,8 @@ from enum import auto
import attr
-if t.TYPE_CHECKING: import openllm_core
+if t.TYPE_CHECKING:
+ import openllm_core
_object_setattr = object.__setattr__
@@ -41,7 +42,7 @@ class Conversation:
# The names of two roles
roles: t.Tuple[str, str] = ('User', 'Assistant')
# All messages. Each item is (role, message).
- messages: t.List[t.List[str]] = []
+ messages: t.List[t.List[str]] = attr.Factory(list)
# The number of few shot examples
offset: int = 0
# The separator style and configurations
@@ -114,7 +115,7 @@ class Conversation:
if self.system_message:
ret = system_prompt
else:
- ret = '[INST] '
+ ret = '[INST] '
for i, (role, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
@@ -175,7 +176,9 @@ class Conversation:
raise ValueError(f'Invalid style: {self.sep_style}')
return ret
- def set_system_message(self, system_message: str) -> None: _object_setattr(self, 'system_message', system_message)
+ def set_system_message(self, system_message: str) -> None:
+ _object_setattr(self, 'system_message', system_message)
+
def append_message(self, role: str, message: str) -> None:
'''Append a new message.'''
self.messages.append([role, message])
@@ -199,6 +202,19 @@ class Conversation:
ret.append({'role': 'assistant', 'content': msg})
return ret
+ def copy(self) -> Conversation:
+ return Conversation(name=self.name,
+ system_template=self.system_template,
+ system_message=self.system_message,
+ roles=self.roles,
+ messages=self.messages,
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ stop_str=self.stop_str,
+ stop_token_ids=self.stop_token_ids)
+
# A global registry for all conversation templates for OpenLLM models
conv_templates: t.Dict[str, Conversation] = {}
@@ -208,7 +224,7 @@ def register_conv_template(template: Conversation) -> None:
def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
if name not in conv_templates: raise ValueError(f'Failed to find conversation templates for {name}')
- template = conv_templates[name]
+ template = conv_templates[name].copy()
if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message)
return template
@@ -217,7 +233,9 @@ register_conv_template(Conversation(name='raw', system_message='', roles=('', ''
# Llama template
# source: https://huggingface.co/blog/codellama#conversational-instructions
-register_conv_template(Conversation(name='llama', system_template='[INST] <>\n{system_message}\n<>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' ',))
+register_conv_template(
+ Conversation(name='llama', system_template='[INST] <>\n{system_message}\n<>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' ',
+ ))
# ChatGLM template
register_conv_template(Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n',))
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
index 9e00f039..371691ca 100644
--- a/openllm-python/src/openllm/protocol/openai.py
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -138,4 +138,5 @@ def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_
for message in messages:
if message['role'] == 'system': conv_template.set_system_message(message['content'])
else: conv_template.append_message(message['role'], message['content'])
+ conv_template.append_message('assistant', '')
return conv_template.get_prompt()