From 022130d0ac9fc9eb4fa7dad5f59630c81b9c0702 Mon Sep 17 00:00:00 2001
From: XunchaoZ <69278985+XunchaoZ@users.noreply.github.com>
Date: Mon, 30 Oct 2023 03:20:43 -0400
Subject: [PATCH] fix(openai): Chat templates (#519)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
openllm-python/src/openllm/_conversation.py | 291 ++++++++++++++++++
openllm-python/src/openllm/_llm.py | 15 +-
openllm-python/src/openllm/_service.py | 17 +-
openllm-python/src/openllm/protocol/openai.py | 11 +-
4 files changed, 314 insertions(+), 20 deletions(-)
create mode 100644 openllm-python/src/openllm/_conversation.py
diff --git a/openllm-python/src/openllm/_conversation.py b/openllm-python/src/openllm/_conversation.py
new file mode 100644
index 00000000..77eb71c7
--- /dev/null
+++ b/openllm-python/src/openllm/_conversation.py
@@ -0,0 +1,291 @@
+from __future__ import annotations
+import typing as t
+
+from enum import IntEnum
+from enum import auto
+
+import attr
+
+if t.TYPE_CHECKING: import openllm_core
+
+_object_setattr = object.__setattr__
+
+class SeparatorStyle(IntEnum):
+ '''Separator styles.'''
+
+ # Generic separator styles for chat models
+ ADD_COLON_SINGLE = auto()
+ ADD_COLON_TWO = auto()
+ ADD_COLON_SPACE_SINGLE = auto()
+ NO_COLON_SINGLE = auto()
+ NO_COLON_TWO = auto()
+ ADD_NEW_LINE_SINGLE = auto()
+
+ # Special separator styles for specific chat models in OpenLLM
+ LLAMA = auto()
+ CHATGLM = auto()
+ DOLLY = auto()
+ MPT = auto()
+ STARCODER = auto()
+
+@attr.define
+class Conversation:
+ '''A class that manages prompt templates and keeps all conversation history.'''
+
+ # The name of this template
+ name: str
+ # The template of the system prompt
+ system_template: str = '{system_message}'
+ # The system message
+ system_message: str = ''
+ # The names of two roles
+ roles: t.Tuple[str, str] = ('User', 'Assistant')
+ # All messages. Each item is (role, message).
+ messages: t.List[t.List[str]] = []
+ # The number of few shot examples
+ offset: int = 0
+ # The separator style and configurations
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+ sep: str = '\n'
+ sep2: str = ''
+ # Stop criteria (the default one is EOS token)
+ stop_str: t.Union[str, t.List[str]] = ''
+ # Stops generation if meeting any token in this list
+ stop_token_ids: t.List[int] = []
+
+ def get_prompt(self) -> str:
+ '''Get the prompt for generation.'''
+ system_prompt = self.system_template.format(system_message=self.system_message)
+
+ # Generic separator styles for chat models
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: # Role with colon
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: # Role with colon, two different separators for two roles
+ seps = [self.sep, self.sep2]
+ ret = system_prompt + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ': ' + message + seps[i % 2]
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: # Add a space after colon
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ': ' # must be end with a space
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: # Add a new line after role
+ ret = '' if system_prompt == '' else system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + message + self.sep
+ else:
+ ret += role + '\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: # No colon
+ ret = system_prompt
+ for role, message in self.messages:
+ if message:
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO: # No colon, two different separators for two roles
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + message + seps[i % 2]
+ else:
+ ret += role
+ return ret
+ # Special separator styles for specific chat models
+ elif self.sep_style == SeparatorStyle.LLAMA:
+ seps = [self.sep, self.sep2]
+ if self.system_message:
+ ret = system_prompt
+ else:
+ ret = '[INST] '
+ for i, (role, message) in enumerate(self.messages):
+ tag = self.roles[i % 2]
+ if message:
+ if i == 0:
+ ret += message + ' '
+ else:
+ ret += tag + ' ' + message + seps[i % 2]
+ else:
+ ret += tag
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATGLM:
+ round_add_n = 1 if self.name == 'chatglm2' else 0
+ if system_prompt:
+ ret = system_prompt + self.sep
+ else:
+ ret = ''
+ for i, (role, message) in enumerate(self.messages):
+ if i % 2 == 0:
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
+ if message:
+ ret += f'{role}:{message}{self.sep}'
+ else:
+ ret += f'{role}:'
+ return ret
+ elif self.sep_style == SeparatorStyle.DOLLY:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ':\n' + message + seps[i % 2]
+ if i % 2 == 1:
+ ret += '\n\n'
+ else:
+ ret += role + ':\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.MPT:
+ if system_prompt:
+ ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}'
+ else:
+ ret = ''
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}'
+ else:
+ ret += f'{role}:'
+ return ret
+ elif self.sep_style == SeparatorStyle.STARCODER:
+ if system_prompt:
+ ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}'
+ else:
+ ret = ''
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += f'{role}\n{message}<|end|>{self.sep}'
+ else:
+ ret += f'{role}:'
+ else:
+ raise ValueError(f'Invalid style: {self.sep_style}')
+ return ret
+
+ def set_system_message(self, system_message: str) -> None: _object_setattr(self, 'system_message', system_message)
+ def append_message(self, role: str, message: str) -> None:
+ '''Append a new message.'''
+ self.messages.append([role, message])
+
+ def update_last_message(self, message: str) -> None:
+ '''Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+ '''
+ self.messages[-1][1] = message
+
+ def to_openai_api_messages(self) -> t.List[t.Dict[str, str]]:
+ '''Convert the conversation to OpenAI chat completion format.'''
+ ret = [{'role': 'system', 'content': self.system_message}]
+
+ for i, (_, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append({'role': 'user', 'content': msg})
+ elif msg is not None:
+ ret.append({'role': 'assistant', 'content': msg})
+ return ret
+
+# A global registry for all conversation templates for OpenLLM models
+conv_templates: t.Dict[str, Conversation] = {}
+
+def register_conv_template(template: Conversation) -> None:
+ '''Register a new conversation template.'''
+ conv_templates[template.name] = template
+
+def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
+ if name not in conv_templates: raise ValueError(f'Failed to find conversation templates for {name}')
+ template = conv_templates[name]
+ if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message)
+ return template
+
+# Raw template
+register_conv_template(Conversation(name='raw', system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
+
+# Llama template
+# source: https://huggingface.co/blog/codellama#conversational-instructions
+register_conv_template(Conversation(name='llama', system_template='[INST] <>\n{system_message}\n<>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' ',))
+
+# ChatGLM template
+register_conv_template(Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n',))
+
+# Dolly-v2 template
+register_conv_template(
+ Conversation(name='dolly_v2',
+ system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
+ roles=('### Instruction', '### Response'),
+ sep_style=SeparatorStyle.DOLLY,
+ sep='\n\n',
+ sep2='### End',
+ ))
+
+# Falcon template
+register_conv_template(
+ # source: https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
+ Conversation(name='falcon', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, # No space after colon
+ sep='\n',
+ ))
+
+# Flan-T5 default template
+register_conv_template(
+ # source: https://www.philschmid.de/fine-tune-flan-t5
+ # No specific template found, but seems to have the same dialogue style
+ Conversation(name='flan-t5', system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
+
+# GPT-NeoX default template
+register_conv_template(
+ # source: https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
+ # Don't know if GPT-NeoX-20B is trained on any chat prompt template
+ Conversation(name='gpt-neox', system_message='', roles=('', ''), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'))
+
+# MPT template
+register_conv_template(
+ # source: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
+ Conversation(name='mpt', roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'))
+
+# OPT template (No reference for OPT found)
+register_conv_template(Conversation(name='opt', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
+
+# StableLM default template
+register_conv_template(
+ Conversation(name='stablelm',
+ system_template='<|SYSTEM|>{system_message}',
+ system_message='''# StableLM Tuned (Alpha version)
+- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+- StableLM will refuse to participate in anything that could harm a human.
+''',
+ roles=('<|USER|>', '<|ASSISTANT|>'),
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
+ sep='',
+ stop_token_ids=[50278, 50279, 50277, 1, 0],
+ ))
+
+# StarCoder default template
+register_conv_template(
+ # source: https://github.com/bigcode-project/starcoder/blob/main/chat/dialogues.py
+ Conversation(name='starcoder', system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'))
+
+# Baichuan default template
+register_conv_template(
+ # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
+ # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
+ # https://github.com/baichuan-inc/Baichuan-13B/issues/25
+ Conversation(name='baichuan', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
+
+# Mistral template
+register_conv_template(Conversation(name='mistral', system_message='', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2='',))
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index dff0eb6d..7e473ad4 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -1149,21 +1149,24 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
- prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+ use_chat_template: bool = attrs.pop('_format_chat_template', False)
+ if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate(prompt, **attrs)
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
- prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+ use_chat_template: bool = attrs.pop('_format_chat_template', False)
+ if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate_one(prompt, stop, **attrs)
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
- prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+ use_chat_template: bool = attrs.pop('_format_chat_template', False)
+ if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
pre = 0
@@ -1178,6 +1181,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]:
+ use_chat_template: bool = attrs.pop('_format_chat_template', False)
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
echo = attrs.pop('echo', False)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
@@ -1187,7 +1191,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
- prompt, *_ = self.sanitize_parameters(prompt, **attrs)
+ if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
if stop_token_ids is None: stop_token_ids = []
@@ -1213,6 +1217,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
+ use_chat_template: bool = attrs.pop('_format_chat_template', False)
# TODO: System prompt support
pre = 0
echo = attrs.pop('echo', False)
@@ -1224,7 +1229,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
- prompt, *_ = self.sanitize_parameters(prompt, **attrs)
+ if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
if stop_token_ids is None: stop_token_ids = []
diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py
index 8ebe0e6a..092fe416 100644
--- a/openllm-python/src/openllm/_service.py
+++ b/openllm-python/src/openllm/_service.py
@@ -66,7 +66,7 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
output=bentoml.io.Text())
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
_model = input_dict.get('model', None)
- if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
+ if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
prompt = input_dict.pop('prompt', None)
if prompt is None: raise ValueError("'prompt' should not be None.")
stream = input_dict.pop('stream', False)
@@ -112,20 +112,12 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
)).decode()
@svc.api(route='/v1/chat/completions',
- input=bentoml.io.JSON.from_sample(
- openllm.utils.bentoml_cattr.unstructure(
- openllm.openai.ChatCompletionRequest(messages=[{
- 'role': 'system',
- 'content': 'You are a helpful assistant.'
- }, {
- 'role': 'user',
- 'content': 'Hello!'
- }], model=runner.llm_type))),
+ input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}], model=runner.llm_type))),
output=bentoml.io.Text())
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
_model = input_dict.get('model', None)
- if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
- prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
+ if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
+ prompt = openllm.openai.messages_to_prompt(input_dict['messages'], model, llm_config)
stream = input_dict.pop('stream', False)
config = {
'temperature': input_dict.pop('temperature', llm_config['temperature']),
@@ -136,6 +128,7 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
+ '_format_chat_template': True,
}
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
index 7f3e3783..9e00f039 100644
--- a/openllm-python/src/openllm/protocol/openai.py
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -6,6 +6,8 @@ import attr
import openllm_core
+from openllm import _conversation
+
@attr.define
class CompletionRequest:
prompt: str
@@ -131,6 +133,9 @@ class ModelList:
object: str = 'list'
data: t.List[ModelCard] = attr.field(factory=list)
-def messages_to_prompt(messages: list[Message]) -> str:
- formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
- return f'{formatted}\nassistant:'
+def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_core.LLMConfig) -> str:
+ conv_template = _conversation.get_conv_template(model, llm_config)
+ for message in messages:
+ if message['role'] == 'system': conv_template.set_system_message(message['content'])
+ else: conv_template.append_message(message['role'], message['content'])
+ return conv_template.get_prompt()