From 392c7a8139c754e08a0808a477a8bfb71a032918 Mon Sep 17 00:00:00 2001 From: XunchaoZ <69278985+XunchaoZ@users.noreply.github.com> Date: Mon, 30 Oct 2023 17:28:42 -0400 Subject: [PATCH] Fix chat template and message list bug (#549) --- openllm-python/src/openllm/_conversation.py | 30 +++++++++++++++---- openllm-python/src/openllm/protocol/openai.py | 1 + 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/openllm-python/src/openllm/_conversation.py b/openllm-python/src/openllm/_conversation.py index 77eb71c7..2577c7af 100644 --- a/openllm-python/src/openllm/_conversation.py +++ b/openllm-python/src/openllm/_conversation.py @@ -6,7 +6,8 @@ from enum import auto import attr -if t.TYPE_CHECKING: import openllm_core +if t.TYPE_CHECKING: + import openllm_core _object_setattr = object.__setattr__ @@ -41,7 +42,7 @@ class Conversation: # The names of two roles roles: t.Tuple[str, str] = ('User', 'Assistant') # All messages. Each item is (role, message). - messages: t.List[t.List[str]] = [] + messages: t.List[t.List[str]] = attr.Factory(list) # The number of few shot examples offset: int = 0 # The separator style and configurations @@ -114,7 +115,7 @@ class Conversation: if self.system_message: ret = system_prompt else: - ret = '[INST] ' + ret = '[INST] ' for i, (role, message) in enumerate(self.messages): tag = self.roles[i % 2] if message: @@ -175,7 +176,9 @@ class Conversation: raise ValueError(f'Invalid style: {self.sep_style}') return ret - def set_system_message(self, system_message: str) -> None: _object_setattr(self, 'system_message', system_message) + def set_system_message(self, system_message: str) -> None: + _object_setattr(self, 'system_message', system_message) + def append_message(self, role: str, message: str) -> None: '''Append a new message.''' self.messages.append([role, message]) @@ -199,6 +202,19 @@ class Conversation: ret.append({'role': 'assistant', 'content': msg}) return ret + def copy(self) -> Conversation: + return Conversation(name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=self.messages, + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids) + # A global registry for all conversation templates for OpenLLM models conv_templates: t.Dict[str, Conversation] = {} @@ -208,7 +224,7 @@ def register_conv_template(template: Conversation) -> None: def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation: if name not in conv_templates: raise ValueError(f'Failed to find conversation templates for {name}') - template = conv_templates[name] + template = conv_templates[name].copy() if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message) return template @@ -217,7 +233,9 @@ register_conv_template(Conversation(name='raw', system_message='', roles=('', '' # Llama template # source: https://huggingface.co/blog/codellama#conversational-instructions -register_conv_template(Conversation(name='llama', system_template='[INST] <>\n{system_message}\n<>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' ',)) +register_conv_template( + Conversation(name='llama', system_template='[INST] <>\n{system_message}\n<>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' ', + )) # ChatGLM template register_conv_template(Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n',)) diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py index 9e00f039..371691ca 100644 --- a/openllm-python/src/openllm/protocol/openai.py +++ b/openllm-python/src/openllm/protocol/openai.py @@ -138,4 +138,5 @@ def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_ for message in messages: if message['role'] == 'system': conv_template.set_system_message(message['content']) else: conv_template.append_message(message['role'], message['content']) + conv_template.append_message('assistant', '') return conv_template.get_prompt()