diff --git a/openllm-python/src/openllm/_conversation.py b/openllm-python/src/openllm/_conversation.py
new file mode 100644
index 00000000..77eb71c7
--- /dev/null
+++ b/openllm-python/src/openllm/_conversation.py
@@ -0,0 +1,291 @@
+from __future__ import annotations
+import typing as t
+
+from enum import IntEnum
+from enum import auto
+
+import attr
+
+if t.TYPE_CHECKING: import openllm_core
+
+_object_setattr = object.__setattr__
+
+class SeparatorStyle(IntEnum):
+  '''Separator styles.'''
+
+  # Generic separator styles for chat models
+  ADD_COLON_SINGLE = auto()
+  ADD_COLON_TWO = auto()
+  ADD_COLON_SPACE_SINGLE = auto()
+  NO_COLON_SINGLE = auto()
+  NO_COLON_TWO = auto()
+  ADD_NEW_LINE_SINGLE = auto()
+
+  # Special separator styles for specific chat models in OpenLLM
+  LLAMA = auto()
+  CHATGLM = auto()
+  DOLLY = auto()
+  MPT = auto()
+  STARCODER = auto()
+
+@attr.define
+class Conversation:
+  '''A class that manages prompt templates and keeps all conversation history.'''
+
+  # The name of this template
+  name: str
+  # The template of the system prompt
+  system_template: str = '{system_message}'
+  # The system message
+  system_message: str = ''
+  # The names of the two roles
+  roles: t.Tuple[str, str] = ('User', 'Assistant')
+  # All messages. Each item is (role, message).
+  messages: t.List[t.List[str]] = []
+  # The number of few-shot examples
+  offset: int = 0
+  # The separator style and configurations
+  sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+  sep: str = '\n'
+  sep2: str = ''
+  # Stop criteria (the default is the EOS token)
+  stop_str: t.Union[str, t.List[str]] = ''
+  # Stops generation when any token in this list is generated
+  stop_token_ids: t.List[int] = []
+
+  def get_prompt(self) -> str:
+    '''Get the prompt for generation.'''
+    system_prompt = self.system_template.format(system_message=self.system_message)
+
+    # Generic separator styles for chat models
+    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:  # Role with colon
+      ret = system_prompt + self.sep
+      for role, message in self.messages:
+        if message:
+          ret += role + ': ' + message + self.sep
+        else:
+          ret += role + ':'
+      return ret
+    elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:  # Role with colon, two different separators for the two roles
+      seps = [self.sep, self.sep2]
+      ret = system_prompt + seps[0]
+      for i, (role, message) in enumerate(self.messages):
+        if message:
+          ret += role + ': ' + message + seps[i % 2]
+        else:
+          ret += role + ':'
+      return ret
+    elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:  # Add a space after the colon
+      ret = system_prompt + self.sep
+      for role, message in self.messages:
+        if message:
+          ret += role + ': ' + message + self.sep
+        else:
+          ret += role + ': '  # must end with a space
+      return ret
+    elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:  # Add a new line after the role
+      ret = '' if system_prompt == '' else system_prompt + self.sep
+      for role, message in self.messages:
+        if message:
+          ret += role + '\n' + message + self.sep
+        else:
+          ret += role + '\n'
+      return ret
+    elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:  # No colon
+      ret = system_prompt
+      for role, message in self.messages:
+        if message:
+          ret += role + message + self.sep
+        else:
+          ret += role
+      return ret
+    elif self.sep_style == SeparatorStyle.NO_COLON_TWO:  # No colon, two different separators for the two roles
+      seps = [self.sep, self.sep2]
+      ret = system_prompt
+      for i, (role, message) in enumerate(self.messages):
+        if message:
+          ret += role + message + seps[i % 2]
+        else:
+          ret += role
+      return ret
+    # Special separator styles for specific chat models
+    elif self.sep_style == SeparatorStyle.LLAMA:
+      seps = [self.sep, self.sep2]
+      if self.system_message:
+        ret = system_prompt
+      else:
+        ret = '[INST] '
+      for i, (role, message) in enumerate(self.messages):
+        tag = self.roles[i % 2]
+        if message:
+          if i == 0:
+            ret += message + ' '
+          else:
+            ret += tag + ' ' + message + seps[i % 2]
+        else:
+          ret += tag
+      return ret
+    elif self.sep_style == SeparatorStyle.CHATGLM:
+      round_add_n = 1 if self.name == 'chatglm2' else 0
+      if system_prompt:
+        ret = system_prompt + self.sep
+      else:
+        ret = ''
+      for i, (role, message) in enumerate(self.messages):
+        if i % 2 == 0:
+          ret += f'[Round {i//2 + round_add_n}]{self.sep}'
+        if message:
+          ret += f'{role}:{message}{self.sep}'
+        else:
+          ret += f'{role}:'
+      return ret
+    elif self.sep_style == SeparatorStyle.DOLLY:
+      seps = [self.sep, self.sep2]
+      ret = system_prompt
+      for i, (role, message) in enumerate(self.messages):
+        if message:
+          ret += role + ':\n' + message + seps[i % 2]
+          if i % 2 == 1:
+            ret += '\n\n'
+        else:
+          ret += role + ':\n'
+      return ret
+    elif self.sep_style == SeparatorStyle.MPT:
+      if system_prompt:
+        ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}'
+      else:
+        ret = ''
+      for i, (role, message) in enumerate(self.messages):
+        if message:
+          ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}'
+        else:
+          ret += f'<|im_start|>{role}\n'  # empty message cues this role to respond
+      return ret
+    elif self.sep_style == SeparatorStyle.STARCODER:
+      if system_prompt:
+        ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}'
+      else:
+        ret = ''
+      for i, (role, message) in enumerate(self.messages):
+        if message:
+          ret += f'{role}\n{message}<|end|>{self.sep}'
+        else:
+          ret += f'{role}\n'  # empty message cues this role to respond
+      return ret
+    else:
+      raise ValueError(f'Invalid style: {self.sep_style}')
+
+  def set_system_message(self, system_message: str) -> None: _object_setattr(self, 'system_message', system_message)
+
+  def append_message(self, role: str, message: str) -> None:
+    '''Append a new message.'''
+    self.messages.append([role, message])
+
+  def update_last_message(self, message: str) -> None:
+    '''Update the last output.
+
+    The last message is typically set to be None when constructing the prompt,
+    so we need to update it in-place after getting the response from a model.
+    '''
+    self.messages[-1][1] = message
+
+  def to_openai_api_messages(self) -> t.List[t.Dict[str, str]]:
+    '''Convert the conversation to OpenAI chat completion format.'''
+    ret = [{'role': 'system', 'content': self.system_message}]
+    for i, (_, msg) in enumerate(self.messages[self.offset:]):
+      if i % 2 == 0:
+        ret.append({'role': 'user', 'content': msg})
+      elif msg is not None:
+        ret.append({'role': 'assistant', 'content': msg})
+    return ret
+
+# A global registry of all conversation templates for OpenLLM models
+conv_templates: t.Dict[str, Conversation] = {}
+
+def register_conv_template(template: Conversation) -> None:
+  '''Register a new conversation template.'''
+  conv_templates[template.name] = template
+
+def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
+  if name not in conv_templates: raise ValueError(f'Failed to find conversation template for {name}')
+  # Hand back a per-request copy: append_message mutates `messages` in place,
+  # and appending to the registered singleton would leak history across requests.
+  template = conv_templates[name]
+  template = attr.evolve(template, messages=list(template.messages))
+  if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message)
+  return template
+
+# Raw template
+register_conv_template(Conversation(name='raw', system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
+
+# Llama template
+# source: https://huggingface.co/blog/codellama#conversational-instructions
+register_conv_template(Conversation(name='llama', system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' '))
+
+# ChatGLM template
+register_conv_template(Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n'))
+
+# Dolly-v2 template
+register_conv_template(
+    Conversation(name='dolly_v2',
+                 system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
+                 roles=('### Instruction', '### Response'),
+                 sep_style=SeparatorStyle.DOLLY,
+                 sep='\n\n',
+                 sep2='### End',
+                 ))
+
+# Falcon template
+register_conv_template(
+    # source: https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
+    Conversation(name='falcon', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
+
+# Flan-T5 default template
+register_conv_template(
+    # source: https://www.philschmid.de/fine-tune-flan-t5
+    # No specific template found, but it seems to follow the same dialogue style
+    Conversation(name='flan-t5', system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
+
+# GPT-NeoX default template
+register_conv_template(
+    # source: https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
+    # It is unclear whether GPT-NeoX-20B was trained on any chat prompt template
+    Conversation(name='gpt-neox', system_message='', roles=('', ''), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'))
+
+# MPT template
+register_conv_template(
+    # source: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
+    Conversation(name='mpt', roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'))
+
+# OPT template (no reference for OPT found)
+register_conv_template(Conversation(name='opt', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
+
+# StableLM default template
+register_conv_template(
+    Conversation(name='stablelm',
+                 system_template='<|SYSTEM|>{system_message}',
+                 system_message='''# StableLM Tuned (Alpha version)
+- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+- StableLM will refuse to participate in anything that could harm a human.
+''',
+                 roles=('<|USER|>', '<|ASSISTANT|>'),
+                 sep_style=SeparatorStyle.NO_COLON_SINGLE,
+                 sep='',
+                 stop_token_ids=[50278, 50279, 50277, 1, 0],
+                 ))
+
+# StarCoder default template
+register_conv_template(
+    # source: https://github.com/bigcode-project/starcoder/blob/main/chat/dialogues.py
+    Conversation(name='starcoder', system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'))
+
+# Baichuan default template
+register_conv_template(
+    # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
+    # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
+    # https://github.com/baichuan-inc/Baichuan-13B/issues/25
+    Conversation(name='baichuan', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
+
+# Mistral template
+register_conv_template(Conversation(name='mistral', system_message='', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=''))
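To make the template plumbing above concrete, here is a minimal sketch of how the module is meant to be driven (not part of the diff). `_FakeConfig` is a hypothetical stand-in for `openllm_core.LLMConfig`; `get_conv_template` only probes it for a `default_system_message` attribute:

from openllm._conversation import get_conv_template

class _FakeConfig:
  default_system_message = 'You are a helpful assistant.'  # hypothetical config stub

conv = get_conv_template('llama', _FakeConfig())
conv.append_message(conv.roles[0], 'What does OpenLLM do?')  # roles[0] == '[INST]'
conv.append_message(conv.roles[1], '')  # empty assistant slot cues the model to answer
print(conv.get_prompt())
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# What does OpenLLM do? [/INST]

Because `get_conv_template` hands back a copy, per-request `append_message` calls never touch the registry entry.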
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index dff0eb6d..7e473ad4 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -1149,21 +1149,24 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
   @bentoml.Runnable.method(**method_signature(generate_sig))  # type: ignore
   def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
-    prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+    use_chat_template: bool = attrs.pop('_format_chat_template', False)
+    if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
     adapter_name = attrs.pop('adapter_name', None)
     if adapter_name is not None: __self.set_adapter(adapter_name)
     return self.generate(prompt, **attrs)
 
   @bentoml.Runnable.method(**method_signature(generate_sig))  # type: ignore
   def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
-    prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+    use_chat_template: bool = attrs.pop('_format_chat_template', False)
+    if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
     adapter_name = attrs.pop('adapter_name', None)
     if adapter_name is not None: __self.set_adapter(adapter_name)
     return self.generate_one(prompt, stop, **attrs)
 
   @bentoml.Runnable.method(**method_signature(generate_iterator_sig))  # type: ignore
   def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
-    prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
+    use_chat_template: bool = attrs.pop('_format_chat_template', False)
+    if not use_chat_template: prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
     adapter_name = attrs.pop('adapter_name', None)
     if adapter_name is not None: __self.set_adapter(adapter_name)
     pre = 0
@@ -1178,6 +1181,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
   @bentoml.Runnable.method(**method_signature(generate_sig))  # type: ignore
   async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]:
+    use_chat_template: bool = attrs.pop('_format_chat_template', False)
     stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
     echo = attrs.pop('echo', False)
     stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
@@ -1187,7 +1191,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
     if adapter_name is not None: __self.set_adapter(adapter_name)
     request_id: str | None = attrs.pop('request_id', None)
     if request_id is None: raise ValueError('request_id must not be None.')
-    prompt, *_ = self.sanitize_parameters(prompt, **attrs)
+    if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
     if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
     if stop_token_ids is None: stop_token_ids = []
@@ -1213,6 +1217,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
   @bentoml.Runnable.method(**method_signature(generate_iterator_sig))  # type: ignore
   async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
+    use_chat_template: bool = attrs.pop('_format_chat_template', False)
     # TODO: System prompt support
     pre = 0
     echo = attrs.pop('echo', False)
@@ -1224,7 +1229,7 @@ def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_i
     if adapter_name is not None: __self.set_adapter(adapter_name)
     request_id: str | None = attrs.pop('request_id', None)
     if request_id is None: raise ValueError('request_id must not be None.')
-    prompt, *_ = self.sanitize_parameters(prompt, **attrs)
+    if not use_chat_template: prompt, *_ = self.sanitize_parameters(prompt, **attrs)
     if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
     if stop_token_ids is None: stop_token_ids = []
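The pattern is identical across all five runnable methods: pop the private `_format_chat_template` flag first, and only run `sanitize_parameters` (which wraps a raw prompt in the model's default prompt template) when the flag is absent. A hedged sketch of the calling convention this enables; `openllm.Runner` is the public constructor, the `.generate.run` access follows standard BentoML runner usage, and the model name is illustrative:

import openllm

runner = openllm.Runner('mistral')

# Default path: sanitize_parameters() wraps the raw prompt in the model's
# prompt template before generation.
out = runner.generate.run('What does OpenLLM do?')

# Chat path: chat_completion_v1 already rendered the prompt through a
# Conversation template, so the flag skips the second round of templating.
pre_rendered = '[INST] What does OpenLLM do? [/INST]'
out = runner.generate.run(pre_rendered, _format_chat_template=True)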
diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py
index 8ebe0e6a..092fe416 100644
--- a/openllm-python/src/openllm/_service.py
+++ b/openllm-python/src/openllm/_service.py
@@ -66,7 +66,7 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
          output=bentoml.io.Text())
 async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
   _model = input_dict.get('model', None)
-  if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
   prompt = input_dict.pop('prompt', None)
   if prompt is None: raise ValueError("'prompt' should not be None.")
   stream = input_dict.pop('stream', False)
@@ -112,20 +112,12 @@ async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> s
          )).decode()
 
 @svc.api(route='/v1/chat/completions',
-         input=bentoml.io.JSON.from_sample(
-             openllm.utils.bentoml_cattr.unstructure(
-                 openllm.openai.ChatCompletionRequest(messages=[{
-                     'role': 'system',
-                     'content': 'You are a helpful assistant.'
-                 }, {
-                     'role': 'user',
-                     'content': 'Hello!'
-                 }], model=runner.llm_type))),
+         input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}], model=runner.llm_type))),
          output=bentoml.io.Text())
 async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
   _model = input_dict.get('model', None)
-  if _model is not runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
-  prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
+  prompt = openllm.openai.messages_to_prompt(input_dict['messages'], model, llm_config)
   stream = input_dict.pop('stream', False)
   config = {
       'temperature': input_dict.pop('temperature', llm_config['temperature']),
@@ -136,6 +128,7 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
       'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
       'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
       'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
+      '_format_chat_template': True,
   }
 
   async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
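From a client's perspective nothing about the wire format changes; only the server-side prompt construction does. A sketch using the pre-1.0 `openai` client (matching the `openai.Model.list()` hint in the warning above); the base URL assumes BentoML's default port and is not part of the diff:

import openai

openai.api_base = 'http://localhost:3000/v1'  # assumed local OpenLLM server
openai.api_key = 'na'  # the endpoint does not check credentials

# The system turn is routed to Conversation.set_system_message() and the user
# turn is rendered through the model's registered template on the server.
resp = openai.ChatCompletion.create(
    model='mistral',  # should match runner.llm_type, otherwise a warning is logged
    messages=[{'role': 'system', 'content': 'You are a helpful assistant.'},
              {'role': 'user', 'content': 'Hello!'}],
    temperature=0.7)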
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
index 7f3e3783..9e00f039 100644
--- a/openllm-python/src/openllm/protocol/openai.py
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -6,6 +6,8 @@
 import attr
 
 import openllm_core
 
+from openllm import _conversation
+
 @attr.define
 class CompletionRequest:
   prompt: str
@@ -131,6 +133,9 @@ class ModelList:
   object: str = 'list'
   data: t.List[ModelCard] = attr.field(factory=list)
 
-def messages_to_prompt(messages: list[Message]) -> str:
-  formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
-  return f'{formatted}\nassistant:'
+def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_core.LLMConfig) -> str:
+  conv_template = _conversation.get_conv_template(model, llm_config)
+  # Map OpenAI roles onto the template's role tags so get_prompt() emits the
+  # model's markers (e.g. '### Instruction') rather than raw 'user'/'assistant'.
+  for message in messages:
+    if message['role'] == 'system': conv_template.set_system_message(message['content'])
+    elif message['role'] == 'user': conv_template.append_message(conv_template.roles[0], message['content'])
+    else: conv_template.append_message(conv_template.roles[1], message['content'])
+  # A trailing empty assistant turn cues the model to respond, mirroring the
+  # old '\nassistant:' suffix (see Conversation.update_last_message).
+  conv_template.append_message(conv_template.roles[1], '')
+  return conv_template.get_prompt()
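A before/after sketch of the helper, assuming the `llama` template registered in `_conversation.py`; `llm_config` stands in for the server's `openllm_core.LLMConfig` instance:

from openllm.protocol.openai import messages_to_prompt

messages = [{'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello!'}]

# Old behavior: the same flat transcript for every model:
#   system: You are a helpful assistant.\nuser: Hello!\nassistant:
# New behavior, rendered through the model's registered template:
print(messages_to_prompt(messages, 'llama', llm_config))
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hello! [/INST]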