mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-21 15:39:36 -04:00
fix(openai): Chat templates (#519)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
291
openllm-python/src/openllm/_conversation.py
Normal file
291
openllm-python/src/openllm/_conversation.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
from enum import IntEnum
|
||||
from enum import auto
|
||||
|
||||
import attr
|
||||
|
||||
if t.TYPE_CHECKING: import openllm_core
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
class SeparatorStyle(IntEnum):
  '''Separator styles.

  Each member selects one rendering strategy in ``Conversation.get_prompt``.
  Values are written out explicitly but match the order the previous
  ``auto()`` declarations produced, so integer comparisons stay stable.
  '''

  # Generic separator styles for chat models
  ADD_COLON_SINGLE = 1
  ADD_COLON_TWO = 2
  ADD_COLON_SPACE_SINGLE = 3
  NO_COLON_SINGLE = 4
  NO_COLON_TWO = 5
  ADD_NEW_LINE_SINGLE = 6

  # Special separator styles for specific chat models in OpenLLM
  LLAMA = 7
  CHATGLM = 8
  DOLLY = 9
  MPT = 10
  STARCODER = 11
|
||||
|
||||
@attr.define
class Conversation:
  '''A class that manages prompt templates and keeps all conversation history.

  ``get_prompt`` renders the accumulated ``messages`` into a single prompt
  string according to ``sep_style``; the remaining attributes configure that
  rendering.
  '''

  # The name of this template
  name: str
  # The template of the system prompt
  system_template: str = '{system_message}'
  # The system message
  system_message: str = ''
  # The names of two roles
  roles: t.Tuple[str, str] = ('User', 'Assistant')
  # All messages. Each item is [role, message].
  # NOTE: a bare ``= []`` default is a single list shared by every instance
  # under attrs, so all registered templates would accumulate each other's
  # history through ``append_message`` — use a per-instance factory instead.
  messages: t.List[t.List[str]] = attr.field(factory=list)
  # The number of few shot examples
  offset: int = 0
  # The separator style and configurations
  sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
  sep: str = '\n'
  sep2: str = ''
  # Stop criteria (the default one is EOS token)
  stop_str: t.Union[str, t.List[str]] = ''
  # Stops generation if meeting any token in this list; same shared-mutable-
  # default concern as ``messages``, hence the factory.
  stop_token_ids: t.List[int] = attr.field(factory=list)

  def get_prompt(self) -> str:
    '''Get the prompt for generation.

    Raises:
        ValueError: if ``sep_style`` is not a known :class:`SeparatorStyle`.
    '''
    system_prompt = self.system_template.format(system_message=self.system_message)

    # Generic separator styles for chat models
    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:  # Role with colon
      ret = system_prompt + self.sep
      for role, message in self.messages:
        if message: ret += role + ': ' + message + self.sep
        else: ret += role + ':'
      return ret
    elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:  # Role with colon, two different separators for two roles
      seps = [self.sep, self.sep2]
      ret = system_prompt + seps[0]
      for i, (role, message) in enumerate(self.messages):
        if message: ret += role + ': ' + message + seps[i % 2]
        else: ret += role + ':'
      return ret
    elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:  # Add a space after colon
      ret = system_prompt + self.sep
      for role, message in self.messages:
        if message: ret += role + ': ' + message + self.sep
        else: ret += role + ': '  # must be end with a space
      return ret
    elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:  # Add a new line after role
      ret = '' if system_prompt == '' else system_prompt + self.sep
      for role, message in self.messages:
        if message: ret += role + '\n' + message + self.sep
        else: ret += role + '\n'
      return ret
    elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:  # No colon
      ret = system_prompt
      for role, message in self.messages:
        if message: ret += role + message + self.sep
        else: ret += role
      return ret
    elif self.sep_style == SeparatorStyle.NO_COLON_TWO:  # No colon, two different separators for two roles
      seps = [self.sep, self.sep2]
      ret = system_prompt
      for i, (role, message) in enumerate(self.messages):
        if message: ret += role + message + seps[i % 2]
        else: ret += role
      return ret
    # Special separator styles for specific chat models
    elif self.sep_style == SeparatorStyle.LLAMA:
      seps = [self.sep, self.sep2]
      ret = system_prompt if self.system_message else '[INST] '
      for i, (role, message) in enumerate(self.messages):
        tag = self.roles[i % 2]
        if message:
          # The first user message is folded directly after the system block.
          if i == 0: ret += message + ' '
          else: ret += tag + ' ' + message + seps[i % 2]
        else:
          ret += tag
      return ret
    elif self.sep_style == SeparatorStyle.CHATGLM:
      # chatglm2 numbers conversation rounds starting from 1 instead of 0
      round_add_n = 1 if self.name == 'chatglm2' else 0
      ret = system_prompt + self.sep if system_prompt else ''
      for i, (role, message) in enumerate(self.messages):
        if i % 2 == 0: ret += f'[Round {i//2 + round_add_n}]{self.sep}'
        if message: ret += f'{role}:{message}{self.sep}'
        else: ret += f'{role}:'
      return ret
    elif self.sep_style == SeparatorStyle.DOLLY:
      seps = [self.sep, self.sep2]
      ret = system_prompt
      for i, (role, message) in enumerate(self.messages):
        if message:
          ret += role + ':\n' + message + seps[i % 2]
          # Extra blank line after each assistant turn
          if i % 2 == 1: ret += '\n\n'
        else:
          ret += role + ':\n'
      return ret
    elif self.sep_style == SeparatorStyle.MPT:
      ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}' if system_prompt else ''
      for role, message in self.messages:
        if message: ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}'
        else: ret += f'{role}:'
      return ret
    elif self.sep_style == SeparatorStyle.STARCODER:
      ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}' if system_prompt else ''
      for role, message in self.messages:
        if message: ret += f'{role}\n{message}<|end|>{self.sep}'
        else: ret += f'{role}:'
      return ret
    else:
      raise ValueError(f'Invalid style: {self.sep_style}')

  def set_system_message(self, system_message: str) -> None:
    '''Set the system message.'''
    _object_setattr(self, 'system_message', system_message)

  def append_message(self, role: str, message: str) -> None:
    '''Append a new message.'''
    self.messages.append([role, message])

  def update_last_message(self, message: str) -> None:
    '''Update the last output.

    The last message is typically set to be None when constructing the prompt,
    so we need to update it in-place after getting the response from a model.
    '''
    self.messages[-1][1] = message

  def to_openai_api_messages(self) -> t.List[t.Dict[str, str]]:
    '''Convert the conversation to OpenAI chat completion format.'''
    ret = [{'role': 'system', 'content': self.system_message}]
    for i, (_, msg) in enumerate(self.messages[self.offset:]):
      # Even indices are user turns, odd indices are assistant turns.
      if i % 2 == 0: ret.append({'role': 'user', 'content': msg})
      elif msg is not None: ret.append({'role': 'assistant', 'content': msg})
    return ret
|
||||
|
||||
# A global registry for all conversation templates for OpenLLM models
conv_templates: t.Dict[str, Conversation] = {}

def register_conv_template(template: Conversation) -> None:
  '''Register a new conversation template.'''
  conv_templates[template.name] = template

def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
  '''Look up a registered template by name, seeding its system message from the config.'''
  try:
    template = conv_templates[name]
  except KeyError:
    raise ValueError(f'Failed to find conversation templates for {name}') from None
  default_system_message = getattr(llm_config, 'default_system_message', None)
  if hasattr(llm_config, 'default_system_message'):
    template.set_system_message(default_system_message)
  return template
|
||||
|
||||
# Raw template
register_conv_template(
    Conversation(name='raw', system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))

# Llama template
# source: https://huggingface.co/blog/codellama#conversational-instructions
register_conv_template(
    Conversation(
        name='llama',
        system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n',
        roles=('[INST]', '[/INST]'),
        sep_style=SeparatorStyle.LLAMA,
        sep=' ',
        sep2=' </s><s>'))

# ChatGLM template
register_conv_template(
    Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n'))

# Dolly-v2 template
register_conv_template(
    Conversation(
        name='dolly_v2',
        system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
        roles=('### Instruction', '### Response'),
        sep_style=SeparatorStyle.DOLLY,
        sep='\n\n',
        sep2='### End'))

# Falcon template
# source: https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
register_conv_template(
    Conversation(
        name='falcon',
        roles=('User', 'Assistant'),
        messages=[],
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,  # No space after colon
        sep='\n'))

# Flan-T5 default template
# source: https://www.philschmid.de/fine-tune-flan-t5
# No specific template found, but seems to have the same dialogue style
register_conv_template(
    Conversation(name='flan-t5', system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))

# GPT-NeoX default template
# source: https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
# Don't know if GPT-NeoX-20B is trained on any chat prompt template
register_conv_template(
    Conversation(name='gpt-neox', system_message='', roles=('<human>', '<bot>'), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'))

# MPT template
# source: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
register_conv_template(
    Conversation(name='mpt', roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'))

# OPT template (No reference for OPT found)
register_conv_template(
    Conversation(name='opt', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))

# StableLM default template
register_conv_template(
    Conversation(
        name='stablelm',
        system_template='<|SYSTEM|>{system_message}',
        system_message='''# StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
''',
        roles=('<|USER|>', '<|ASSISTANT|>'),
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep='',
        stop_token_ids=[50278, 50279, 50277, 1, 0]))

# StarCoder default template
# source: https://github.com/bigcode-project/starcoder/blob/main/chat/dialogues.py
register_conv_template(
    Conversation(name='starcoder', system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'))

# Baichuan default template
# source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
# https://github.com/baichuan-inc/Baichuan-13B/issues/25
register_conv_template(
    Conversation(name='baichuan', roles=('<reserved_102>', '<reserved_103>'), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))

# Mistral template
register_conv_template(
    Conversation(name='mistral', system_message='', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2='</s>'))
|
||||
@@ -1149,21 +1149,24 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
use_chat_template: bool = attrs.pop('_format_chat_template', False)
|
||||
if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
return self.generate(prompt, **attrs)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
use_chat_template: bool = attrs.pop('_format_chat_template', False)
|
||||
if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
return self.generate_one(prompt, stop, **attrs)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
|
||||
def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
use_chat_template: bool = attrs.pop('_format_chat_template', False)
|
||||
if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
pre = 0
|
||||
@@ -1178,6 +1181,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]:
|
||||
use_chat_template: bool = attrs.pop('_format_chat_template', False)
|
||||
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
|
||||
echo = attrs.pop('echo', False)
|
||||
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
|
||||
@@ -1187,7 +1191,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
request_id: str | None = attrs.pop('request_id', None)
|
||||
if request_id is None: raise ValueError('request_id must not be None.')
|
||||
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
@@ -1213,6 +1217,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
|
||||
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
use_chat_template: bool = attrs.pop('_format_chat_template', False)
|
||||
# TODO: System prompt support
|
||||
pre = 0
|
||||
echo = attrs.pop('echo', False)
|
||||
@@ -1224,7 +1229,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
request_id: str | None = attrs.pop('request_id', None)
|
||||
if request_id is None: raise ValueError('request_id must not be None.')
|
||||
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
|
||||
@@ -66,7 +66,7 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
|
||||
output=bentoml.io.Text())
|
||||
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = input_dict.pop('prompt', None)
|
||||
if prompt is None: raise ValueError("'prompt' should not be None.")
|
||||
stream = input_dict.pop('stream', False)
|
||||
@@ -112,20 +112,12 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
|
||||
)).decode()
|
||||
|
||||
@svc.api(route='/v1/chat/completions',
|
||||
input=bentoml.io.JSON.from_sample(
|
||||
openllm.utils.bentoml_cattr.unstructure(
|
||||
openllm.openai.ChatCompletionRequest(messages=[{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant.'
|
||||
}, {
|
||||
'role': 'user',
|
||||
'content': 'Hello!'
|
||||
}], model=runner.llm_type))),
|
||||
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}], model=runner.llm_type))),
|
||||
output=bentoml.io.Text())
|
||||
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
|
||||
if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = openllm.openai.messages_to_prompt(input_dict['messages'], model, llm_config)
|
||||
stream = input_dict.pop('stream', False)
|
||||
config = {
|
||||
'temperature': input_dict.pop('temperature', llm_config['temperature']),
|
||||
@@ -136,6 +128,7 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
|
||||
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
|
||||
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
|
||||
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
|
||||
'_format_chat_template': True,
|
||||
}
|
||||
|
||||
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
|
||||
|
||||
@@ -6,6 +6,8 @@ import attr
|
||||
|
||||
import openllm_core
|
||||
|
||||
from openllm import _conversation
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
prompt: str
|
||||
@@ -131,6 +133,9 @@ class ModelList:
|
||||
object: str = 'list'
|
||||
data: t.List[ModelCard] = attr.field(factory=list)
|
||||
|
||||
def messages_to_prompt(messages: list[Message]) -> str:
|
||||
formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
|
||||
return f'{formatted}\nassistant:'
|
||||
def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_core.LLMConfig) -> str:
|
||||
conv_template = _conversation.get_conv_template(model, llm_config)
|
||||
for message in messages:
|
||||
if message['role'] == 'system': conv_template.set_system_message(message['content'])
|
||||
else: conv_template.append_message(message['role'], message['content'])
|
||||
return conv_template.get_prompt()
|
||||
|
||||
Reference in New Issue
Block a user