Fix chat template and message list bug (#549)

This commit is contained in:
XunchaoZ
2023-10-30 17:28:42 -04:00
committed by GitHub
parent 050b209334
commit 392c7a8139
2 changed files with 25 additions and 6 deletions

View File

@@ -6,7 +6,8 @@ from enum import auto
import attr
if t.TYPE_CHECKING: import openllm_core
if t.TYPE_CHECKING:
import openllm_core
_object_setattr = object.__setattr__
@@ -41,7 +42,7 @@ class Conversation:
# The names of two roles
roles: t.Tuple[str, str] = ('User', 'Assistant')
# All messages. Each item is (role, message).
messages: t.List[t.List[str]] = []
messages: t.List[t.List[str]] = attr.Factory(list)
# The number of few shot examples
offset: int = 0
# The separator style and configurations
@@ -114,7 +115,7 @@ class Conversation:
if self.system_message:
ret = system_prompt
else:
ret = '[INST] '
ret = '<s>[INST] '
for i, (role, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
@@ -175,7 +176,9 @@ class Conversation:
raise ValueError(f'Invalid style: {self.sep_style}')
return ret
def set_system_message(self, system_message: str) -> None: _object_setattr(self, 'system_message', system_message)
def set_system_message(self, system_message: str) -> None:
_object_setattr(self, 'system_message', system_message)
def append_message(self, role: str, message: str) -> None:
  '''Add one [role, message] entry to the end of this conversation's messages.'''
  entry = [role, message]
  self.messages.append(entry)
@@ -199,6 +202,19 @@ class Conversation:
ret.append({'role': 'assistant', 'content': msg})
return ret
def copy(self) -> Conversation:
  '''Return a new Conversation carrying the same settings as this one.

  The message list is duplicated, including each [role, message] entry, so
  appending to or editing the messages of the returned object leaves this
  instance's messages untouched. All other fields are passed through as-is.
  '''
  return Conversation(name=self.name,
                      system_template=self.system_template,
                      system_message=self.system_message,
                      roles=self.roles,
                      # Passing self.messages directly would make both objects share one list,
                      # so append_message() on the copy would also grow the original.
                      messages=[list(entry) for entry in self.messages],
                      offset=self.offset,
                      sep_style=self.sep_style,
                      sep=self.sep,
                      sep2=self.sep2,
                      stop_str=self.stop_str,
                      stop_token_ids=self.stop_token_ids)
# A global registry for all conversation templates for OpenLLM models
conv_templates: t.Dict[str, Conversation] = {}
@@ -208,7 +224,7 @@ def register_conv_template(template: Conversation) -> None:
def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
if name not in conv_templates: raise ValueError(f'Failed to find conversation templates for {name}')
template = conv_templates[name]
template = conv_templates[name].copy()
if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message)
return template
@@ -217,7 +233,9 @@ register_conv_template(Conversation(name='raw', system_message='', roles=('', ''
# Llama template
# source: https://huggingface.co/blog/codellama#conversational-instructions
register_conv_template(Conversation(name='llama', system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' </s><s>',))
register_conv_template(
Conversation(name='llama', system_template='<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' </s><s>',
))
# ChatGLM template
register_conv_template(Conversation(name='chatglm', roles=('', ''), sep_style=SeparatorStyle.CHATGLM, sep='\n',))

View File

@@ -138,4 +138,5 @@ def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_
for message in messages:
if message['role'] == 'system': conv_template.set_system_message(message['content'])
else: conv_template.append_message(message['role'], message['content'])
conv_template.append_message('assistant', '')
return conv_template.get_prompt()