From a0e0f81306d9020b457033033fb095bd5bd8a880 Mon Sep 17 00:00:00 2001 From: MingLiangDai <95396491+MingLiangDai@users.noreply.github.com> Date: Tue, 3 Oct 2023 09:53:37 -0400 Subject: [PATCH] feat: PromptTemplate and system prompt support (#407) Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- openllm-core/src/openllm_core/__init__.py | 2 + openllm-core/src/openllm_core/_prompt.py | 32 ------------ .../config/configuration_baichuan.py | 13 +++-- .../config/configuration_chatglm.py | 5 ++ .../config/configuration_dolly_v2.py | 5 +- .../config/configuration_falcon.py | 5 +- .../config/configuration_flan_t5.py | 5 +- .../config/configuration_gpt_neox.py | 5 +- .../config/configuration_llama.py | 18 ++++--- .../openllm_core/config/configuration_mpt.py | 5 +- .../openllm_core/config/configuration_opt.py | 7 ++- .../config/configuration_stablelm.py | 5 +- .../config/configuration_starcoder.py | 5 ++ .../src/openllm_core/prompts/__init__.py | 4 ++ .../openllm_core/prompts/prompt_template.py | 49 +++++++++++++++++++ .../src/openllm_core/prompts/utils.py | 18 +++++++ openllm-python/src/openllm/__init__.py | 2 + openllm-python/src/openllm/_llm.py | 38 +++++++++++--- openllm-python/src/openllm/bundle/_package.py | 7 ++- openllm-python/src/openllm/cli/_factory.py | 24 +++++++++ openllm-python/src/openllm/cli/_sdk.py | 12 +++++ openllm-python/src/openllm/cli/entrypoint.py | 13 +++-- .../src/openllm/cli/extension/get_prompt.py | 8 ++- openllm-python/src/openllm/prompts.py | 3 ++ 24 files changed, 227 insertions(+), 63 deletions(-) delete mode 100644 openllm-core/src/openllm_core/_prompt.py create mode 100644 openllm-core/src/openllm_core/prompts/__init__.py create mode 100644 openllm-core/src/openllm_core/prompts/prompt_template.py create mode 100644 openllm-core/src/openllm_core/prompts/utils.py create mode 100644 openllm-python/src/openllm/prompts.py diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py index 1a43a138..36ed04ef 100644 --- a/openllm-core/src/openllm_core/__init__.py +++ b/openllm-core/src/openllm_core/__init__.py @@ -42,3 +42,5 @@ from .config import MPTConfig as MPTConfig from .config import OPTConfig as OPTConfig from .config import StableLMConfig as StableLMConfig from .config import StarCoderConfig as StarCoderConfig +from .prompts import PromptTemplate as PromptTemplate +from .prompts import process_prompt as process_prompt diff --git a/openllm-core/src/openllm_core/_prompt.py b/openllm-core/src/openllm_core/_prompt.py deleted file mode 100644 index 3fe60d89..00000000 --- a/openllm-core/src/openllm_core/_prompt.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations -import string -import typing as t - -class PromptFormatter(string.Formatter): - """This PromptFormatter is largely based on langchain's implementation.""" - def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: - if len(args) > 0: raise ValueError('Positional arguments are not supported') - return super().vformat(format_string, args, kwargs) - - def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: - extras = set(kwargs).difference(used_args) - if extras: raise KeyError(f'Extra params passed: {extras}') - - def extract_template_variables(self, template: str) -> t.Sequence[str]: - return [field[1] for field in self.parse(template) if field[1] is not None] - -default_formatter = PromptFormatter() - -def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str: - # Currently, all default prompt will always have `instruction` key. - if not use_prompt_template: return prompt - elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'") - template_variables = default_formatter.extract_template_variables(template) - prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} - if 'instruction' in prompt_variables: - raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'") - try: - return template.format(instruction=prompt, **prompt_variables) - except KeyError as e: - raise RuntimeError( - f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None diff --git a/openllm-core/src/openllm_core/config/configuration_baichuan.py b/openllm-core/src/openllm_core/config/configuration_baichuan.py index c6b97732..0522c54d 100644 --- a/openllm-core/src/openllm_core/config/configuration_baichuan.py +++ b/openllm-core/src/openllm_core/config/configuration_baichuan.py @@ -3,7 +3,7 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate START_BAICHUAN_COMMAND_DOCSTRING = '''\ Run a LLMServer for Baichuan model. @@ -24,7 +24,8 @@ or provide `--model-id` flag when running ``openllm start baichuan``: \b $ openllm start baichuan --model-id='fireballoon/baichuan-vicuna-chinese-7b' ''' -DEFAULT_PROMPT_TEMPLATE = '''{instruction}''' +DEFAULT_SYSTEM_MESSAGE = '' +DEFAULT_PROMPT_TEMPLATE = PromptTemplate('{instruction}') class BaichuanConfig(openllm_core.LLMConfig): """Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology. @@ -61,12 +62,16 @@ class BaichuanConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, - use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, {} + system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message + if prompt_template is None: prompt_template = DEFAULT_PROMPT_TEMPLATE + elif isinstance(prompt_template, str): prompt_template = PromptTemplate(template=prompt_template) + return prompt_template.with_options(system_message=system_message).format(instruction=prompt), {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, {} def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] diff --git a/openllm-core/src/openllm_core/config/configuration_chatglm.py b/openllm-core/src/openllm_core/config/configuration_chatglm.py index e5c5e671..5e73a78c 100644 --- a/openllm-core/src/openllm_core/config/configuration_chatglm.py +++ b/openllm-core/src/openllm_core/config/configuration_chatglm.py @@ -5,6 +5,9 @@ import openllm_core from openllm_core.utils import dantic +if t.TYPE_CHECKING: + from openllm_core.prompts import PromptTemplate + START_CHATGLM_COMMAND_DOCSTRING = '''\ Run a LLMServer for ChatGLM model. @@ -61,6 +64,8 @@ class ChatGLMConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py index 6eeb5473..dbdf4a9b 100644 --- a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py +++ b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt from openllm_core.utils import dantic if t.TYPE_CHECKING: @@ -91,6 +92,8 @@ class DollyV2Config(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_falcon.py b/openllm-core/src/openllm_core/config/configuration_falcon.py index bb292486..06c0b2a5 100644 --- a/openllm-core/src/openllm_core/config/configuration_falcon.py +++ b/openllm-core/src/openllm_core/config/configuration_falcon.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt START_FALCON_COMMAND_DOCSTRING = '''\ Run a LLMServer for FalconLM model. @@ -61,6 +62,8 @@ class FalconConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_flan_t5.py b/openllm-core/src/openllm_core/config/configuration_flan_t5.py index 8d5f2fa0..37964fdc 100644 --- a/openllm-core/src/openllm_core/config/configuration_flan_t5.py +++ b/openllm-core/src/openllm_core/config/configuration_flan_t5.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt START_FLAN_T5_COMMAND_DOCSTRING = '''\ Run a LLMServer for FLAN-T5 model. @@ -56,6 +57,8 @@ class FlanT5Config(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py index 31e1d98a..a330e0fb 100644 --- a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py +++ b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt from openllm_core.utils import dantic START_GPT_NEOX_COMMAND_DOCSTRING = '''\ @@ -58,6 +59,8 @@ class GPTNeoXConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, diff --git a/openllm-core/src/openllm_core/config/configuration_llama.py b/openllm-core/src/openllm_core/config/configuration_llama.py index 1a7b6495..464605e8 100644 --- a/openllm-core/src/openllm_core/config/configuration_llama.py +++ b/openllm-core/src/openllm_core/config/configuration_llama.py @@ -3,7 +3,7 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate from openllm_core.utils import dantic START_LLAMA_COMMAND_DOCSTRING = '''\ @@ -38,7 +38,7 @@ OpenLLM also supports running Llama-2 and its fine-tune and variants. To import \b $ CONVERTER=hf-llama2 openllm import llama /path/to/llama-2 ''' -SYSTEM_MESSAGE = ''' +DEFAULT_SYSTEM_MESSAGE = ''' You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. @@ -47,13 +47,13 @@ SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '< str: - return PROMPT_MAPPING[model_type] +def _get_prompt(model_type: t.Literal['v1', 'v2']) -> PromptTemplate: + return PromptTemplate(PROMPT_MAPPING[model_type]) DEFAULT_PROMPT_TEMPLATE = _get_prompt @@ -112,14 +112,18 @@ class LlamaConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, - use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1') if use_default_prompt_template else None, use_default_prompt_template, **attrs), { + system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message + if prompt_template is None: prompt_template = DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1') + elif isinstance(prompt_template, str): prompt_template = PromptTemplate(template=prompt_template) + return prompt_template.with_options(system_message=system_message).format(instruction=prompt), { 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k }, {} diff --git a/openllm-core/src/openllm_core/config/configuration_mpt.py b/openllm-core/src/openllm_core/config/configuration_mpt.py index 474e02d2..e000330b 100644 --- a/openllm-core/src/openllm_core/config/configuration_mpt.py +++ b/openllm-core/src/openllm_core/config/configuration_mpt.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt from openllm_core.utils import dantic MPTPromptType = t.Literal['default', 'instruct', 'chat', 'storywriter'] @@ -86,6 +87,8 @@ class MPTConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_opt.py b/openllm-core/src/openllm_core/config/configuration_opt.py index 9bdfe152..b57f3739 100644 --- a/openllm-core/src/openllm_core/config/configuration_opt.py +++ b/openllm-core/src/openllm_core/config/configuration_opt.py @@ -3,9 +3,12 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import process_prompt from openllm_core.utils import dantic +if t.TYPE_CHECKING: + from openllm_core.prompts.prompt_template import PromptTemplate + START_OPT_COMMAND_DOCSTRING = '''\ Run a LLMServer for OPT model. @@ -64,6 +67,8 @@ class OPTConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_stablelm.py b/openllm-core/src/openllm_core/config/configuration_stablelm.py index 53b57715..52c06f1b 100644 --- a/openllm-core/src/openllm_core/config/configuration_stablelm.py +++ b/openllm-core/src/openllm_core/config/configuration_stablelm.py @@ -3,7 +3,8 @@ import typing as t import openllm_core -from openllm_core._prompt import process_prompt +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt START_STABLELM_COMMAND_DOCSTRING = '''\ Run a LLMServer for StableLM model. @@ -62,6 +63,8 @@ class StableLMConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, diff --git a/openllm-core/src/openllm_core/config/configuration_starcoder.py b/openllm-core/src/openllm_core/config/configuration_starcoder.py index 05ba414b..178ea619 100644 --- a/openllm-core/src/openllm_core/config/configuration_starcoder.py +++ b/openllm-core/src/openllm_core/config/configuration_starcoder.py @@ -3,6 +3,9 @@ import typing as t import openllm_core +if t.TYPE_CHECKING: + from openllm_core.prompts import PromptTemplate + START_STARCODER_COMMAND_DOCSTRING = '''\ Run a LLMServer for StarCoder model. @@ -54,6 +57,8 @@ class StarCoderConfig(openllm_core.LLMConfig): def sanitize_parameters(self, prompt: str, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, diff --git a/openllm-core/src/openllm_core/prompts/__init__.py b/openllm-core/src/openllm_core/prompts/__init__.py new file mode 100644 index 00000000..6c8fd7f3 --- /dev/null +++ b/openllm-core/src/openllm_core/prompts/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from .prompt_template import PromptTemplate as PromptTemplate +from .prompt_template import process_prompt as process_prompt diff --git a/openllm-core/src/openllm_core/prompts/prompt_template.py b/openllm-core/src/openllm_core/prompts/prompt_template.py new file mode 100644 index 00000000..675e67a0 --- /dev/null +++ b/openllm-core/src/openllm_core/prompts/prompt_template.py @@ -0,0 +1,49 @@ +from __future__ import annotations +import typing as t + +import attr + +from .utils import default_formatter + +# equivocal setattr to save one lookup per assignment +_object_setattr = object.__setattr__ + +@attr.define(slots=True) +class PromptTemplate: + template: str + _input_variables: t.Sequence[str] = attr.field(init=False) + + def __attrs_post_init__(self) -> None: + self._input_variables = default_formatter.extract_template_variables(self.template) + + def with_options(self, **attrs: t.Any) -> PromptTemplate: + prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables} + o = attr.evolve(self, template=self.template.format(**prompt_variables)) + _object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template)) + return o + + def to_string(self) -> str: + return self.template + + def format(self, **attrs: t.Any) -> str: + prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables} + try: + return self.template.format(**prompt_variables) + except KeyError as e: + raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None + +# TODO: remove process_prompt after refactor config for all models +def process_prompt(prompt: str, template: PromptTemplate | str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str: + # Currently, all default prompt will always have `instruction` key. + if not use_prompt_template: return prompt + elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'") + if isinstance(template, PromptTemplate): template = template.to_string() + template_variables = default_formatter.extract_template_variables(template) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} + if 'instruction' in prompt_variables: + raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'") + try: + return template.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None diff --git a/openllm-core/src/openllm_core/prompts/utils.py b/openllm-core/src/openllm_core/prompts/utils.py new file mode 100644 index 00000000..58bc2f2c --- /dev/null +++ b/openllm-core/src/openllm_core/prompts/utils.py @@ -0,0 +1,18 @@ +from __future__ import annotations +import string +import typing as t + +class PromptFormatter(string.Formatter): + """This PromptFormatter is largely based on langchain's implementation.""" + def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: + if len(args) > 0: raise ValueError('Positional arguments are not supported') + return super().vformat(format_string, args, kwargs) + + def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: + extras = set(kwargs).difference(used_args) + if extras: raise KeyError(f'Extra params passed: {extras}') + + def extract_template_variables(self, template: str) -> t.Sequence[str]: + return [field[1] for field in self.parse(template) if field[1] is not None] + +default_formatter = PromptFormatter() diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py index 019427d3..16d7e5e2 100644 --- a/openllm-python/src/openllm/__init__.py +++ b/openllm-python/src/openllm/__init__.py @@ -39,6 +39,7 @@ _import_structure: dict[str, list[str]] = { "bundle": [], "playground": [], "testing": [], + "prompts": ["PromptTemplate"], "utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"], @@ -70,6 +71,7 @@ if _t.TYPE_CHECKING: from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES from .serialisation import ggml as ggml, transformers as transformers + from .prompts import PromptTemplate as PromptTemplate from .utils import infer_auto_class as infer_auto_class try: diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 3c318cf3..bd606acd 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -22,7 +22,6 @@ import openllm_core from bentoml._internal.models.model import ModelSignature from openllm_core._configuration import FineTuneConfig from openllm_core._configuration import LLMConfig -from openllm_core._prompt import process_prompt from openllm_core._schema import EmbeddingsOutput from openllm_core._typing_compat import AdaptersMapping from openllm_core._typing_compat import AdaptersTuple @@ -40,6 +39,8 @@ from openllm_core._typing_compat import PeftAdapterOutput from openllm_core._typing_compat import T from openllm_core._typing_compat import TupleAny from openllm_core._typing_compat import overload +from openllm_core.prompts import PromptTemplate +from openllm_core.prompts import process_prompt from openllm_core.utils import DEBUG from openllm_core.utils import MYPY from openllm_core.utils import EnvVarMixin @@ -293,6 +294,8 @@ class LLM(LLMInterface[M, T], ReprMixin): model_version: t.Optional[str], serialisation: LiteralSerialisation, _local: bool, + prompt_template: PromptTemplate | None, + system_message: str | None, **attrs: t.Any) -> None: '''Generated __attrs_init__ for openllm.LLM.''' @@ -310,6 +313,8 @@ class LLM(LLMInterface[M, T], ReprMixin): _model_version: str _serialisation: LiteralSerialisation _local: bool + _prompt_template: PromptTemplate | None + _system_message: str | None def __init_subclass__(cls: type[LLM[M, T]]) -> None: cd = cls.__dict__ @@ -375,6 +380,8 @@ class LLM(LLMInterface[M, T], ReprMixin): def from_pretrained(cls, model_id: str | None = None, model_version: str | None = None, + prompt_template: PromptTemplate | str | None = None, + system_message: str | None = None, llm_config: LLMConfig | None = None, *args: t.Any, quantize: LiteralQuantise | None = None, @@ -421,6 +428,8 @@ class LLM(LLMInterface[M, T], ReprMixin): model_version: Optional version for this given model id. Default to None. This is useful for saving from custom path. If set to None, the version will either be the git hash from given pretrained model, or the hash inferred from last modified time of the given directory. + system_message: Optional system message for what the system prompt for the specified LLM is. If not given, the default system message will be used. + prompt_template: Optional custom prompt template. If not given, the default prompt template for the specified model will be used. llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM will use `config_class` to construct default configuration. quantize: The quantization to use for this LLM. Defaults to None. Possible values @@ -457,6 +466,7 @@ class LLM(LLMInterface[M, T], ReprMixin): if adapter_map is not None and not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'") if adapter_map: logger.debug('OpenLLM will apply the following adapters layers: %s', list(adapter_map)) + if isinstance(prompt_template, str): prompt_template = PromptTemplate(prompt_template) if llm_config is None: llm_config = cls.config_class.model_construct_env(**attrs) @@ -476,6 +486,8 @@ class LLM(LLMInterface[M, T], ReprMixin): quantization_config=quantization_config, _quantize=quantize, _model_version=_tag.version, + _prompt_template=prompt_template, + _system_message=system_message, _tag=_tag, _serialisation=serialisation, _local=_local, @@ -538,6 +550,8 @@ class LLM(LLMInterface[M, T], ReprMixin): _tag: bentoml.Tag, _serialisation: LiteralSerialisation, _local: bool, + _prompt_template: PromptTemplate | None, + _system_message: str | None, _adapters_mapping: AdaptersMapping | None, **attrs: t.Any, ): @@ -650,7 +664,9 @@ class LLM(LLMInterface[M, T], ReprMixin): _adapters_mapping, _model_version, _serialisation, - _local) + _local, + _prompt_template, + _system_message) self.llm_post_init() @@ -728,6 +744,7 @@ class LLM(LLMInterface[M, T], ReprMixin): - The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig - The attributes dictionary that will be passed into `self.postprocess_generate`. ''' + attrs.update({'prompt_template': self._prompt_template, 'system_message': self._system_message}) return self.config.sanitize_parameters(prompt, **attrs) def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any: @@ -937,9 +954,10 @@ class LLM(LLMInterface[M, T], ReprMixin): prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs) return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs) - def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]: - max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device) - src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([])) + def generate_one(self, prompt: str, stop: list[str], **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]: + prompt, generate_kwargs, _ = self.sanitize_parameters(prompt, **attrs) + max_new_tokens, encoded_inputs = generate_kwargs.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device) + src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], generate_kwargs.pop('stopping_criteria', openllm.StoppingCriteriaList([])) stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) # Inference API returns the stop sequence @@ -949,6 +967,7 @@ class LLM(LLMInterface[M, T], ReprMixin): def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]: # TODO: support different generation strategies, similar to self.model.generate + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) for it in self.generate_iterator(prompt, **attrs): pass return [it] @@ -968,6 +987,7 @@ class LLM(LLMInterface[M, T], ReprMixin): from ._generation import is_partial_stop from ._generation import prepare_logits_processor + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) len_prompt = len(prompt) config = self.config.model_construct_env(**attrs) if stop_token_ids is None: stop_token_ids = [] @@ -1138,7 +1158,9 @@ def Runner(model_name: str, attrs.update({ 'model_id': llm_config['env']['model_id_value'], 'quantize': llm_config['env']['quantize_value'], - 'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation']) + 'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation']), + 'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None), + 'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None), }) backend = t.cast(LiteralBackend, first_not_none(backend, default=EnvVarMixin(model_name, backend=llm_config.default_backend() if llm_config is not None else 'pt')['backend_value'])) @@ -1180,24 +1202,28 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) return self.generate(prompt, **attrs) @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) return self.generate(prompt, **attrs) @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]: + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) return self.generate_one(prompt, stop, **attrs) @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]: + prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs) adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) pre = 0 diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py index 9261c55e..f0172012 100644 --- a/openllm-python/src/openllm/bundle/_package.py +++ b/openllm-python/src/openllm/bundle/_package.py @@ -147,6 +147,8 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], 'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'", } if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1') + if llm._system_message: env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message) + if llm._prompt_template: env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string()) # We need to handle None separately here, as env from subprocess doesn't accept None value. _env = openllm_core.utils.EnvVarMixin(llm.config['model_name'], quantize=quantize) @@ -212,10 +214,11 @@ def create_bento(bento_tag: bentoml.Tag, container_version_strategy: LiteralContainerVersionStrategy = 'release', _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento: - backend_envvar = llm.config['env']['backend_value'] _serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation']) labels = dict(llm.identifying_params) - labels.update({'_type': llm.llm_type, '_framework': backend_envvar, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'}) + labels.update({ + '_type': llm.llm_type, '_framework': llm.config['env']['backend_value'], 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle' + }) if adapter_map: labels.update(adapter_map) if isinstance(workers_per_resource, str): if workers_per_resource == 'round_robin': workers_per_resource = 1.0 diff --git a/openllm-python/src/openllm/cli/_factory.py b/openllm-python/src/openllm/cli/_factory.py index 89a9d1d0..2ff736ca 100644 --- a/openllm-python/src/openllm/cli/_factory.py +++ b/openllm-python/src/openllm/cli/_factory.py @@ -123,6 +123,8 @@ Available official model_id(s): [default: {llm_config['default_id']}] server_timeout: int, model_id: str | None, model_version: str | None, + system_message: str | None, + prompt_template_file: t.IO[t.Any] | None, workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, @@ -175,6 +177,7 @@ Available official model_id(s): [default: {llm_config['default_id']}] start_env = os.environ.copy() start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env) + prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None start_env.update({ 'OPENLLM_MODEL': model, 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), @@ -185,10 +188,14 @@ Available official model_id(s): [default: {llm_config['default_id']}] }) if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value']) if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value']) + if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message + if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model, model_id=start_env[env.model_id], model_version=model_version, + prompt_template=prompt_template, + system_message=system_message, llm_config=config, ensure_available=True, adapter_map=adapter_map, @@ -233,6 +240,8 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."), model_id_option(factory=cog.optgroup), model_version_option(factory=cog.optgroup), + system_message_option(factory=cog.optgroup), + prompt_template_file_option(factory=cog.optgroup), cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup), @@ -377,6 +386,21 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f) +def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option('--system-message', + type=click.STRING, + default=None, + envvar='OPENLLM_SYSTEM_MESSAGE', + help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.', + **attrs)(f) + +def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option('--prompt-template-file', + type=click.File(), + default=None, + help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.', + **attrs)(f) + def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: # NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip # XXX: remove the check for __args__ once we have ggml and mlc supports diff --git a/openllm-python/src/openllm/cli/_sdk.py b/openllm-python/src/openllm/cli/_sdk.py index 55391203..cc3856c8 100644 --- a/openllm-python/src/openllm/cli/_sdk.py +++ b/openllm-python/src/openllm/cli/_sdk.py @@ -40,6 +40,8 @@ def _start(model_name: str, workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None, device: tuple[str, ...] | t.Literal['all'] | None = None, quantize: LiteralQuantise | None = None, + system_message: str | None = None, + prompt_template_file: str | None = None, adapter_map: dict[LiteralString, str | None] | None = None, backend: LiteralBackend | None = None, additional_args: list[str] | None = None, @@ -61,6 +63,8 @@ def _start(model_name: str, model_name: The model name to start this LLM model_id: Optional model id for this given LLM timeout: The server timeout + system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message. + prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.. workers_per_resource: Number of workers per resource assigned. See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy) for more information. By default, this is set to 1. @@ -90,6 +94,8 @@ def _start(model_name: str, args: list[str] = [] if model_id: args.extend(['--model-id', model_id]) + if system_message: args.extend(['--system-message', system_message]) + if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)]) if timeout: args.extend(['--server-timeout', str(timeout)]) if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource]) @@ -113,6 +119,8 @@ def _build(model_name: str, bento_version: str | None = None, quantize: LiteralQuantise | None = None, adapter_map: dict[str, str | None] | None = None, + system_message: str | None = None, + prompt_template_file: str | None = None, build_ctx: str | None = None, enable_features: tuple[str, ...] | None = None, workers_per_resource: float | None = None, @@ -137,6 +145,8 @@ def _build(model_name: str, model_id: Optional model id for this given LLM model_version: Optional model version for this given LLM bento_version: Optional bento veresion for this given BentoLLM + system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message. + prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.. quantize: Quantize the model weights. This is only applicable for PyTorch models. Possible quantisation strategies: - int8: Quantize the model with 8bit (bitsandbytes required) @@ -181,6 +191,8 @@ def _build(model_name: str, if enable_features: args.extend([f'--enable-features={f}' for f in enable_features]) if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)]) if overwrite: args.append('--overwrite') + if system_message: args.extend(['--system-message', system_message]) + if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)]) if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) if model_version: args.extend(['--model-version', model_version]) if bento_version: args.extend(['--bento-version', bento_version]) diff --git a/openllm-python/src/openllm/cli/entrypoint.py b/openllm-python/src/openllm/cli/entrypoint.py index 60deb07d..2367e364 100644 --- a/openllm-python/src/openllm/cli/entrypoint.py +++ b/openllm-python/src/openllm/cli/entrypoint.py @@ -101,9 +101,11 @@ from ._factory import model_id_option from ._factory import model_name_argument from ._factory import model_version_option from ._factory import output_option +from ._factory import prompt_template_file_option from ._factory import quantize_option from ._factory import serialisation_option from ._factory import start_command_factory +from ._factory import system_message_option from ._factory import workers_per_resource_option if t.TYPE_CHECKING: @@ -426,6 +428,8 @@ def import_command( @output_option @machine_option @backend_option +@system_message_option +@prompt_template_file_option @click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.') @click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.') @workers_per_resource_option(factory=click, build=True) @@ -478,6 +482,8 @@ def build_command( adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend, + system_message: str | None, + prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, @@ -514,6 +520,7 @@ def build_command( llm_config = AutoConfig.for_model(model_name) _serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation']) env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize) + prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path @@ -521,10 +528,10 @@ def build_command( os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']}) if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value']) if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value']) + if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message + if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template - llm = infer_auto_class(env['backend_value']).for_model( - model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs - ) + llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs) labels = dict(llm.identifying_params) labels.update({'_type': llm.llm_type, '_framework': env['backend_value']}) diff --git a/openllm-python/src/openllm/cli/extension/get_prompt.py b/openllm-python/src/openllm/cli/extension/get_prompt.py index f83d7798..f0b393b9 100644 --- a/openllm-python/src/openllm/cli/extension/get_prompt.py +++ b/openllm-python/src/openllm/cli/extension/get_prompt.py @@ -13,7 +13,7 @@ from openllm.cli import termui from openllm.cli._factory import machine_option from openllm.cli._factory import model_complete_envvar from openllm.cli._factory import output_option -from openllm_core._prompt import process_prompt +from openllm_core.prompts import process_prompt LiteralOutput = t.Literal['json', 'pretty', 'porcelain'] @@ -51,7 +51,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _prompt_template = template(format) else: _prompt_template = template - fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized) + try: + # backward-compatible. TO BE REMOVED once every model has default system message and prompt template. + fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized) + except RuntimeError: + fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0] if machine: return repr(fully_formatted) elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white') elif output == 'json': diff --git a/openllm-python/src/openllm/prompts.py b/openllm-python/src/openllm/prompts.py new file mode 100644 index 00000000..35987e98 --- /dev/null +++ b/openllm-python/src/openllm/prompts.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from openllm_core.prompts import PromptTemplate as PromptTemplate