diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index f57140f5..cbc410c0 100644 --- a/openllm-client/src/openllm_client/_schemas.py +++ b/openllm-client/src/openllm_client/_schemas.py @@ -31,8 +31,6 @@ class Metadata(_SchemaMixin): model_name: str backend: str configuration: t.Dict[str, t.Any] - prompt_template: t.Optional[str] - system_message: t.Optional[str] def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metadata: @@ -49,8 +47,6 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada model_name=data['model_name'], backend=data['backend'], configuration=configuration, - prompt_template=data['prompt_template'], - system_message=data['system_message'], ) except Exception as e: raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py index 9132c8f5..a69c727a 100644 --- a/openllm-core/src/openllm_core/__init__.py +++ b/openllm-core/src/openllm_core/__init__.py @@ -26,4 +26,3 @@ from .config import ( StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, ) -from .prompts import PromptTemplate as PromptTemplate diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index d7a286ef..400cb287 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -1460,13 +1460,13 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): def inference_options(self, llm: openllm.LLM[M, T], backend: str | None = None) -> tuple[Self, t.Any]: backend = backend if backend is not None else llm.__llm_backend__ - cls = getattr(self, backend, None) - if cls is None: + framework = getattr(self, backend, None) + if framework is None: raise ValueError(f'Unknown backend {backend}') try: - return self, cls.build(self) + return self, framework.build(self) except AttributeError: - raise RuntimeError(f'Unknown backend {llm.__llm_backend__}') from None + raise RuntimeError(f'Unknown backend {backend}') from None class vllm: @staticmethod @@ -1503,73 +1503,78 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): def build(config: LLMConfig) -> transformers.GenerationConfig: return transformers.GenerationConfig(**converter.unstructure(config.generation_config)) - def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: - warnings.warn( - "'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 - ) - _, config = self.inference_options(None, 'hf') - return config.to_dict() if return_as_dict else config - - def to_sampling_config(self) -> vllm.SamplingParams: - warnings.warn( - "'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 - ) - return self.inference_options(None, 'vllm')[-1] + @overload + def compatible_options(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ... @overload - def with_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ... + def compatible_options(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ... - @overload - def with_request(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ... 
- - def with_request(self, request: AttrsInstance) -> dict[str, t.Any]: + def compatible_options(self, request: AttrsInstance) -> dict[str, t.Any]: if importlib.util.find_spec('openllm') is None: raise MissingDependencyError( - "'openllm' is required to use 'with_request'. Make sure to install with 'pip install openllm'." + "'openllm' is required to use 'compatible_options'. Make sure to install with 'pip install openllm'." ) from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest if isinstance(request, (ChatCompletionRequest, CompletionRequest)): - return self._with_openai_request(request) + return self.openai.build(self, request) elif isinstance(request, (CohereChatRequest, CohereGenerateRequest)): - return self._with_cohere_request(request) + return self.cohere.build(self, request) else: raise TypeError(f'Unknown request type {type(request)}') - def _with_openai_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: - d = dict( - temperature=first_not_none(request.temperature, self['temperature']), - top_p=first_not_none(request.top_p, self['top_p']), - top_k=first_not_none(request.top_k, self['top_k']), - best_of=first_not_none(request.best_of, self['best_of']), - n=first_not_none(request.n, default=self['n']), - stop=first_not_none(request.stop, default=None), - max_new_tokens=first_not_none(request.max_tokens, default=self['max_new_tokens']), - presence_penalty=first_not_none(request.presence_penalty, default=self['presence_penalty']), - frequency_penalty=first_not_none(request.frequency_penalty, default=self['frequency_penalty']), - ) - if hasattr(request, 'logprobs'): - d['logprobs'] = first_not_none(request.logprobs, default=self['logprobs']) - return d + class openai: + @staticmethod + def build(config: LLMConfig, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: + d = dict( + temperature=first_not_none(request.temperature, config['temperature']), + top_p=first_not_none(request.top_p, config['top_p']), + top_k=first_not_none(request.top_k, config['top_k']), + best_of=first_not_none(request.best_of, config['best_of']), + n=first_not_none(request.n, default=config['n']), + stop=first_not_none(request.stop, default=None), + max_new_tokens=first_not_none(request.max_tokens, default=config['max_new_tokens']), + presence_penalty=first_not_none(request.presence_penalty, default=config['presence_penalty']), + frequency_penalty=first_not_none(request.frequency_penalty, default=config['frequency_penalty']), + ) + if hasattr(request, 'logprobs'): + d['logprobs'] = first_not_none(request.logprobs, default=config['logprobs']) + return d - def _with_cohere_request(self, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]: - d = dict( - max_new_tokens=first_not_none(request.max_tokens, default=self['max_new_tokens']), - temperature=first_not_none(request.temperature, default=self['temperature']), - top_k=first_not_none(request.k, default=self['top_k']), - top_p=first_not_none(request.p, default=self['top_p']), - ) - if hasattr(request, 'num_generations'): - d['n'] = first_not_none(request.num_generations, default=self['n']) - if hasattr(request, 'frequency_penalty'): - d['frequency_penalty'] = first_not_none(request.frequency_penalty, default=self['frequency_penalty']) - if hasattr(request, 'presence_penalty'): - d['presence_penalty'] = first_not_none(request.presence_penalty, default=self['presence_penalty']) - return d + 
class cohere: + @staticmethod + def build(config: LLMConfig, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]: + d = dict( + max_new_tokens=first_not_none(request.max_tokens, default=config['max_new_tokens']), + temperature=first_not_none(request.temperature, default=config['temperature']), + top_k=first_not_none(request.k, default=config['top_k']), + top_p=first_not_none(request.p, default=config['top_p']), + ) + if hasattr(request, 'num_generations'): + d['n'] = first_not_none(request.num_generations, default=config['n']) + if hasattr(request, 'frequency_penalty'): + d['frequency_penalty'] = first_not_none(request.frequency_penalty, default=config['frequency_penalty']) + if hasattr(request, 'presence_penalty'): + d['presence_penalty'] = first_not_none(request.presence_penalty, default=config['presence_penalty']) + return d + + @property + def template(self) -> str: + # Return the default prompt templates for given models. + # This is probably only useful for debugging. Users should be responsible + # for formatting the prompt themselves. + # by default it is just a '{instruction}' + return '{system_message}{instruction}' + + @property + def system_message(self) -> str: + # Return the default system message for given models. + # This should really only be used for chat models. + return '' @classmethod - def to_click_options(cls, f: AnyCallable) -> click.Command: + def parse(cls, f: AnyCallable) -> click.Command: """Convert current configuration to click options. This can be used as a decorator for click commands. @@ -1614,6 +1619,20 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): def peft_task_type(cls) -> str: return PEFT_TASK_TYPE_TARGET_MAPPING[cls.__openllm_model_type__] + # deprecated + def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: + warnings.warn( + "'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 + ) + _, config = self.inference_options(None, 'hf') + return config.to_dict() if return_as_dict else config + + def to_sampling_config(self) -> vllm.SamplingParams: + warnings.warn( + "'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 + ) + return self.inference_options(None, 'vllm')[-1] + converter.register_unstructure_hook_factory( lambda cls: lenient_issubclass(cls, LLMConfig), diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py index f8e94994..537e97b7 100644 --- a/openllm-core/src/openllm_core/_schemas.py +++ b/openllm-core/src/openllm_core/_schemas.py @@ -34,8 +34,6 @@ class MetadataOutput(_SchemaMixin): model_name: str backend: str configuration: str - prompt_template: t.Optional[str] - system_message: t.Optional[str] def model_dump(self) -> dict[str, t.Any]: return { @@ -44,8 +42,6 @@ class MetadataOutput(_SchemaMixin): 'model_name': self.model_name, 'backend': self.backend, 'configuration': self.configuration, - 'prompt_template': orjson.dumps(self.prompt_template).decode(), - 'system_message': orjson.dumps(self.system_message).decode(), } diff --git a/openllm-core/src/openllm_core/config/configuration_baichuan.py b/openllm-core/src/openllm_core/config/configuration_baichuan.py index a899e53d..d1c23272 100644 --- a/openllm-core/src/openllm_core/config/configuration_baichuan.py +++ b/openllm-core/src/openllm_core/config/configuration_baichuan.py @@ -1,11 +1,6 @@ from __future__ import annotations -import 
typing as t import openllm_core -from openllm_core.prompts import PromptTemplate - -DEFAULT_SYSTEM_MESSAGE = '' -DEFAULT_PROMPT_TEMPLATE = PromptTemplate('{instruction}') class BaichuanConfig(openllm_core.LLMConfig): @@ -46,32 +41,3 @@ class BaichuanConfig(openllm_core.LLMConfig): max_new_tokens: int = 2048 top_p: float = 0.7 temperature: float = 0.95 - - @property - def default_prompt_template(self) -> str: - return DEFAULT_PROMPT_TEMPLATE.to_string() - - @property - def default_system_message(self) -> str: - return DEFAULT_SYSTEM_MESSAGE - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - top_p: float | None = None, - temperature: float | None = None, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message - if prompt_template is None: - prompt_template = DEFAULT_PROMPT_TEMPLATE - elif isinstance(prompt_template, str): - prompt_template = PromptTemplate(template=prompt_template) - return ( - prompt_template.with_options(system_message=system_message).format(instruction=prompt), - {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_chatglm.py b/openllm-core/src/openllm_core/config/configuration_chatglm.py index c700ad52..345862ab 100644 --- a/openllm-core/src/openllm_core/config/configuration_chatglm.py +++ b/openllm-core/src/openllm_core/config/configuration_chatglm.py @@ -1,13 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.utils import dantic - -if t.TYPE_CHECKING: - from openllm_core.prompts import PromptTemplate - -DEFAULT_PROMPT_TEMPLATE = '{instruction}' class ChatGLMConfig(openllm_core.LLMConfig): @@ -43,40 +36,9 @@ class ChatGLMConfig(openllm_core.LLMConfig): 'thudm/chatglm3-6b', ], } - retain_history: bool = dantic.Field( - False, - description='Whether to retain history given to the model. 
If set to True, then the model will retain given history.', - ) - use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.') class GenerationConfig: max_new_tokens: int = 2048 num_beams: int = 1 top_p: float = 0.7 temperature: float = 0.95 - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - num_beams: int | None = None, - top_p: float | None = None, - temperature: float | None = None, - chat_history: list[tuple[str, str]] | None = None, - use_default_prompt_template: bool = False, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - prompt_text = '' - if use_default_prompt_template and chat_history is not None: - for i, (old_query, response) in enumerate(chat_history): - prompt_text += f'[Round {i}]\n问:{old_query}\n答:{response}\n' - prompt_text += f'[Round {len(chat_history)}]\n问:{prompt}\n答:' - else: - prompt_text = prompt - return ( - prompt_text, - {'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs}, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py index 276aec60..1114c22a 100644 --- a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py +++ b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py @@ -2,8 +2,6 @@ from __future__ import annotations import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate, process_prompt -from openllm_core.utils import dantic if t.TYPE_CHECKING: import transformers @@ -14,32 +12,9 @@ END_KEY = '### End' INTRO_BLURB = ( 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' ) -# NOTE: This is the prompt that is used for generating responses using an already -# trained model. It ends with the response key, where the job of the model is to provide -# the completion that follows it (i.e. the response itself). -DEFAULT_PROMPT_TEMPLATE = """{intro} -{instruction_key} -{instruction} -{response_key} -""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY) def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int: - """Gets the token ID for a given string that has been added to the tokenizer as a special token. - - When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are - treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. - - Args: - tokenizer: the tokenizer - key: the key to convert to a single token - - Raises: - RuntimeError: if more than one ID was generated - - Returns: - int: the token ID for the given key. 
- """ token_ids = tokenizer.encode(key) if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}") @@ -66,7 +41,6 @@ class DollyV2Config(openllm_core.LLMConfig): 'default_id': 'databricks/dolly-v2-3b', 'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b'], } - return_full_text: bool = dantic.Field(False, description='Whether to return the full prompt to the users.') class GenerationConfig: temperature: float = 0.9 @@ -75,20 +49,8 @@ class DollyV2Config(openllm_core.LLMConfig): max_new_tokens: int = 256 eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY) - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - temperature: float | None = None, - top_k: int | None = None, - top_p: float | None = None, - use_default_prompt_template: bool = True, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return ( - process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), - {'max_new_tokens': max_new_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **attrs}, - {}, + @property + def template(self): + return '{intro}\n{instruction_key}\n{instruction}\n{response_key}\n'.format( + intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY ) diff --git a/openllm-core/src/openllm_core/config/configuration_falcon.py b/openllm-core/src/openllm_core/config/configuration_falcon.py index 8598746d..2557e8f1 100644 --- a/openllm-core/src/openllm_core/config/configuration_falcon.py +++ b/openllm-core/src/openllm_core/config/configuration_falcon.py @@ -1,13 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate, process_prompt - -DEFAULT_PROMPT_TEMPLATE = """{context} -{user_name}: {instruction} -{agent}: -""" class FalconConfig(openllm_core.LLMConfig): @@ -46,26 +39,6 @@ class FalconConfig(openllm_core.LLMConfig): num_return_sequences: int = 1 num_beams: int = 4 - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - top_k: int | None = None, - num_return_sequences: int | None = None, - eos_token_id: int | None = None, - use_default_prompt_template: bool = False, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return ( - process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), - { - 'max_new_tokens': max_new_tokens, - 'top_k': top_k, - 'num_return_sequences': num_return_sequences, - 'eos_token_id': eos_token_id, - **attrs, - }, - {}, - ) + @property + def template(self) -> str: + return '{context}\n{user_name}: {instruction}\n{agent}:' diff --git a/openllm-core/src/openllm_core/config/configuration_flan_t5.py b/openllm-core/src/openllm_core/config/configuration_flan_t5.py index 8567db5d..4de691f6 100644 --- a/openllm-core/src/openllm_core/config/configuration_flan_t5.py +++ b/openllm-core/src/openllm_core/config/configuration_flan_t5.py @@ -1,10 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate, process_prompt - -DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:""" class 
FlanT5Config(openllm_core.LLMConfig): @@ -38,27 +34,6 @@ class FlanT5Config(openllm_core.LLMConfig): top_p: float = 0.4 repetition_penalty = 1.0 - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - temperature: float | None = None, - top_k: int | None = None, - top_p: float | None = None, - repetition_penalty: float | None = None, - use_default_prompt_template: bool = True, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return ( - process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), - { - 'max_new_tokens': max_new_tokens, - 'temperature': temperature, - 'top_k': top_k, - 'top_p': top_p, - 'repetition_penalty': repetition_penalty, - }, - {}, - ) + @property + def template(self) -> str: + return 'Answer the following question:\nQuestion: {instruction}\nAnswer:' diff --git a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py index e8dae7da..9ba58ef8 100644 --- a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py +++ b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py @@ -1,11 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate, process_prompt -from openllm_core.utils import dantic - -DEFAULT_PROMPT_TEMPLATE = """{instruction}""" class GPTNeoXConfig(openllm_core.LLMConfig): @@ -33,24 +28,7 @@ class GPTNeoXConfig(openllm_core.LLMConfig): 'default_id': 'eleutherai/gpt-neox-20b', 'model_ids': ['eleutherai/gpt-neox-20b'], } - use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.') class GenerationConfig: temperature: float = 0.9 max_new_tokens: int = 100 - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - temperature: float | None = None, - max_new_tokens: int | None = None, - use_default_prompt_template: bool = True, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return ( - process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), - {'max_new_tokens': max_new_tokens, 'temperature': temperature}, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_llama.py b/openllm-core/src/openllm_core/config/configuration_llama.py index cab33a69..ff88db9f 100644 --- a/openllm-core/src/openllm_core/config/configuration_llama.py +++ b/openllm-core/src/openllm_core/config/configuration_llama.py @@ -1,34 +1,8 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate -DEFAULT_SYSTEM_MESSAGE = """ -You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. 
-""" SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<>', '', '' -# TODO: support history and v1 prompt implementation -_v1_prompt, _v2_prompt = ( - """{instruction}""", - """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n""".format( - start_key=SINST_KEY, - sys_key=SYS_KEY, - system_message='{system_message}', - instruction='{instruction}', - end_key=EINST_KEY, - ), -) -PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt} - - -def _get_prompt(model_type: t.Literal['v1', 'v2']) -> PromptTemplate: - return PromptTemplate(PROMPT_MAPPING[model_type]) - - -DEFAULT_PROMPT_TEMPLATE = _get_prompt class LlamaConfig(openllm_core.LLMConfig): @@ -80,31 +54,15 @@ class LlamaConfig(openllm_core.LLMConfig): presence_penalty: float = 0.5 @property - def default_prompt_template(self) -> str: - return DEFAULT_PROMPT_TEMPLATE('v2').to_string() + def template(self) -> str: + return '{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'.format( + start_key=SINST_KEY, + sys_key=SYS_KEY, + system_message='{system_message}', + instruction='{instruction}', + end_key=EINST_KEY, + ) @property - def default_system_message(self) -> str: - return DEFAULT_SYSTEM_MESSAGE - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - top_k: int | None = None, - top_p: float | None = None, - temperature: float | None = None, - max_new_tokens: int | None = None, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message - if prompt_template is None: - prompt_template = DEFAULT_PROMPT_TEMPLATE('v2') - elif isinstance(prompt_template, str): - prompt_template = PromptTemplate(template=prompt_template) - return ( - prompt_template.with_options(system_message=system_message).format(instruction=prompt), - {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k}, - {}, - ) + def system_message(self) -> str: + return "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" diff --git a/openllm-core/src/openllm_core/config/configuration_mistral.py b/openllm-core/src/openllm_core/config/configuration_mistral.py index 95531aee..e95ab0b7 100644 --- a/openllm-core/src/openllm_core/config/configuration_mistral.py +++ b/openllm-core/src/openllm_core/config/configuration_mistral.py @@ -1,19 +1,8 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate -# https://docs.mistral.ai/usage/guardrailing/ -DEFAULT_SYSTEM_MESSAGE = """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. 
Ensure replies promote fairness and positivity."""
 SINST_KEY, EINST_KEY, BOS_TOKEN, EOS_TOKEN = '[INST]', '[/INST]', '<s>', '</s>'
-DEFAULT_PROMPT_TEMPLATE = """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format(
-  start_inst=SINST_KEY,
-  end_inst=EINST_KEY,
-  start_key=BOS_TOKEN,
-  system_message='{system_message}',
-  instruction='{instruction}',
-)


 class MistralConfig(openllm_core.LLMConfig):
@@ -33,8 +22,7 @@ class MistralConfig(openllm_core.LLMConfig):
     'default_id': 'mistralai/Mistral-7B-Instruct-v0.1',
     'serialisation': 'safetensors',
     'backend': ('pt', 'vllm'),
-    # NOTE: see https://docs.mistral.ai/usage/guardrailing/
-    # and https://docs.mistral.ai/llm/mistral-instruct-v0.1
+    # NOTE: see https://docs.mistral.ai/usage/guardrailing/ and https://docs.mistral.ai/llm/mistral-instruct-v0.1
     'model_ids': [
       'HuggingFaceH4/zephyr-7b-alpha',
       'HuggingFaceH4/zephyr-7b-beta',
@@ -57,32 +45,16 @@ class MistralConfig(openllm_core.LLMConfig):
     presence_penalty: float = 0.5

   @property
-  def default_prompt_template(self) -> str:
-    return DEFAULT_PROMPT_TEMPLATE
-
-  @property
-  def default_system_message(self) -> str:
-    return DEFAULT_SYSTEM_MESSAGE
-
-  def sanitize_parameters(
-    self,
-    prompt: str,
-    prompt_template: PromptTemplate | str | None = None,
-    system_message: str | None = None,
-    top_k: int | None = None,
-    top_p: float | None = None,
-    temperature: float | None = None,
-    max_new_tokens: int | None = None,
-    **attrs: t.Any,
-  ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
-    system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
-    if prompt_template is None:
-      prompt_template = DEFAULT_PROMPT_TEMPLATE
-    elif isinstance(prompt_template, str):
-      prompt_template = PromptTemplate(template=prompt_template)
-
-    return (
-      prompt_template.with_options(system_message=system_message).format(instruction=prompt),
-      {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
-      {},
+  def template(self) -> str:
+    return """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format(
+      start_inst=SINST_KEY,
+      end_inst=EINST_KEY,
+      start_key=BOS_TOKEN,
+      system_message='{system_message}',
+      instruction='{instruction}',
     )
+
+  # NOTE: https://docs.mistral.ai/usage/guardrailing/
+  @property
+  def system_message(self) -> str:
+    return """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity."""
diff --git a/openllm-core/src/openllm_core/config/configuration_mpt.py b/openllm-core/src/openllm_core/config/configuration_mpt.py
index b0105eda..ee5f730d 100644
--- a/openllm-core/src/openllm_core/config/configuration_mpt.py
+++ b/openllm-core/src/openllm_core/config/configuration_mpt.py
@@ -1,41 +1,6 @@
 from __future__ import annotations
-import typing as t
 import openllm_core
-from openllm_core.prompts import PromptTemplate, process_prompt
-from openllm_core.utils import dantic
-
-MPTPromptType = t.Literal['default', 'instruct', 'chat', 'storywriter']
-
-INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = '### Instruction:', '### Response:', '### End'
-INTRO_BLURB = (
-  'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
-)
-# NOTE: This is the prompt that is used for generating responses using an already
-# trained model. 
It ends with the response key, where the job of the model is to provide -# the completion that follows it (i.e. the response itself). -_chat_prompt, _default_prompt, _instruct_prompt = ( - """{instruction}""", - """{instruction}""", - """{intro} -{instruction_key} -{instruction} -{response_key} -""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY), -) -PROMPT_MAPPING = { - 'default': _default_prompt, - 'instruct': _instruct_prompt, - 'storywriter': _default_prompt, - 'chat': _chat_prompt, -} - - -def _get_prompt(model_type: str) -> str: - return PROMPT_MAPPING[model_type] - - -DEFAULT_PROMPT_TEMPLATE = _get_prompt class MPTConfig(openllm_core.LLMConfig): @@ -67,46 +32,8 @@ class MPTConfig(openllm_core.LLMConfig): 'mosaicml/mpt-30b-chat', ], } - prompt_type: MPTPromptType = dantic.Field( - '"default"', - description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.', - ) - max_sequence_length: int = dantic.Field( - 2048, - description='Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)', - ) class GenerationConfig: max_new_tokens: int = 128 temperature: float = 0 top_p: float = 0.8 - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - temperature: float | None = None, - top_p: float | None = None, - prompt_type: MPTPromptType | None = None, - use_default_prompt_template: bool = True, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - _template = None - if use_default_prompt_template: - if prompt_type is None: - if 'instruct' in self.model_id: - prompt_type = 'instruct' - elif 'storywriter' in self.model_id: - prompt_type = 'storywriter' - elif 'chat' in self.model_id: - prompt_type = 'chat' - else: - prompt_type = 'default' - _template = DEFAULT_PROMPT_TEMPLATE(prompt_type) - return ( - process_prompt(prompt, _template, use_default_prompt_template), - {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_opt.py b/openllm-core/src/openllm_core/config/configuration_opt.py index 6e22c04d..1676896d 100644 --- a/openllm-core/src/openllm_core/config/configuration_opt.py +++ b/openllm-core/src/openllm_core/config/configuration_opt.py @@ -1,13 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import process_prompt - -if t.TYPE_CHECKING: - from openllm_core.prompts.prompt_template import PromptTemplate - -DEFAULT_PROMPT_TEMPLATE = """{instruction}""" class OPTConfig(openllm_core.LLMConfig): @@ -52,26 +45,3 @@ class OPTConfig(openllm_core.LLMConfig): temperature: float = 0.75 max_new_tokens: int = 256 num_return_sequences: int = 1 - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - temperature: float | None = None, - top_k: int | None = None, - num_return_sequences: int | None = None, - use_default_prompt_template: bool = False, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - return ( - process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), - { - 
'max_new_tokens': max_new_tokens, - 'temperature': temperature, - 'top_k': top_k, - 'num_return_sequences': num_return_sequences, - }, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_phi.py b/openllm-core/src/openllm_core/config/configuration_phi.py index 31abee79..8c9e4844 100644 --- a/openllm-core/src/openllm_core/config/configuration_phi.py +++ b/openllm-core/src/openllm_core/config/configuration_phi.py @@ -1,12 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate - -DEFAULT_PROMPT_TEMPLATE = """{instruction}""" - -DEFAULT_SYSTEM_MESSAGE = '' class PhiConfig(openllm_core.LLMConfig): @@ -40,30 +34,3 @@ class PhiConfig(openllm_core.LLMConfig): class SamplingParams: best_of: int = 1 - - @property - def default_prompt_template(self) -> str: - return DEFAULT_PROMPT_TEMPLATE - - @property - def default_system_message(self) -> str: - return DEFAULT_SYSTEM_MESSAGE - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - max_new_tokens: int | None = None, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message - if prompt_template is None: - prompt_template = PromptTemplate(template=self.default_prompt_template) - elif isinstance(prompt_template, str): - prompt_template = PromptTemplate(template=prompt_template) - return ( - prompt_template.with_options(system_message=system_message).format(instruction=prompt), - {'max_new_tokens': max_new_tokens}, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_stablelm.py b/openllm-core/src/openllm_core/config/configuration_stablelm.py index 2025787a..6402196c 100644 --- a/openllm-core/src/openllm_core/config/configuration_stablelm.py +++ b/openllm-core/src/openllm_core/config/configuration_stablelm.py @@ -1,16 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate, process_prompt - -SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version) -- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. -- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. -- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. -- StableLM will refuse to participate in anything that could harm a human. 
-""" -DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""" class StableLMConfig(openllm_core.LLMConfig): @@ -47,27 +37,15 @@ class StableLMConfig(openllm_core.LLMConfig): top_k: int = 0 top_p: float = 0.9 - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - temperature: float | None = None, - max_new_tokens: int | None = None, - top_k: int | None = None, - top_p: float | None = None, - use_default_prompt_template: bool = False, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - if 'tuned' in self._model_id and use_default_prompt_template: - system_prompt = attrs.pop('system_prompt', SYSTEM_PROMPT) - prompt_text = process_prompt( - prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs - ) - else: - prompt_text = prompt - return ( - prompt_text, - {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p}, - {}, - ) + @property + def template(self) -> str: + return '{system_message}<|USER|>{instruction}<|ASSISTANT|>' + + @property + def system_message(self) -> str: + return """<|SYSTEM|># StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. +- StableLM will refuse to participate in anything that could harm a human. +""" diff --git a/openllm-core/src/openllm_core/config/configuration_starcoder.py b/openllm-core/src/openllm_core/config/configuration_starcoder.py index b752c2ea..8d7aaa37 100644 --- a/openllm-core/src/openllm_core/config/configuration_starcoder.py +++ b/openllm-core/src/openllm_core/config/configuration_starcoder.py @@ -1,21 +1,7 @@ from __future__ import annotations -import typing as t import openllm_core -if t.TYPE_CHECKING: - from openllm_core.prompts import PromptTemplate - -DEFAULT_PROMPT_TEMPLATE = """{instruction}""" -FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = ( - '', - '', - '', - '', - '<|endoftext|>', - '', -) - class StarCoderConfig(openllm_core.LLMConfig): """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded. 
@@ -44,37 +30,3 @@ class StarCoderConfig(openllm_core.LLMConfig): top_p: float = 0.95 pad_token_id: int = 49152 repetition_penalty: float = 1.2 - - def sanitize_parameters( - self, - prompt: str, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - temperature: float | None = None, - top_p: float | None = None, - max_new_tokens: int | None = None, - repetition_penalty: float | None = None, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None - if fim_mode: - try: - prefix, suffix = prompt.split(FIM_INDICATOR) - except Exception as err: - raise ValueError(f'Only one {FIM_INDICATOR} allowed in prompt') from err - prompt_text = f'{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}' - else: - prompt_text = prompt - # XXX: This value for pad_token_id is currently a hack, need more investigate why the default starcoder doesn't include the same value as santacoder EOD - return ( - prompt_text, - { - 'temperature': temperature, - 'top_p': top_p, - 'max_new_tokens': max_new_tokens, - 'repetition_penalty': repetition_penalty, - 'pad_token_id': 49152, - **attrs, - }, - {}, - ) diff --git a/openllm-core/src/openllm_core/config/configuration_yi.py b/openllm-core/src/openllm_core/config/configuration_yi.py index ce33f5ed..63512350 100644 --- a/openllm-core/src/openllm_core/config/configuration_yi.py +++ b/openllm-core/src/openllm_core/config/configuration_yi.py @@ -1,11 +1,6 @@ from __future__ import annotations -import typing as t import openllm_core -from openllm_core.prompts import PromptTemplate - -DEFAULT_SYSTEM_MESSAGE = '' -DEFAULT_PROMPT_TEMPLATE = '{instruction}' class YiConfig(openllm_core.LLMConfig): @@ -39,35 +34,3 @@ class YiConfig(openllm_core.LLMConfig): class SamplingParams: best_of: int = 1 presence_penalty: float = 0.5 - - @property - def default_prompt_template(self) -> str: - return DEFAULT_PROMPT_TEMPLATE - - @property - def default_system_message(self) -> str: - return DEFAULT_SYSTEM_MESSAGE - - def sanitize_parameters( - self, - prompt: str, - top_k: int, - top_p: float, - temperature: float, - max_new_tokens: int, - repetition_penalty: float, - prompt_template: PromptTemplate | str | None = None, - system_message: str | None = None, - **attrs: t.Any, - ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message - if prompt_template is None: - prompt_template = DEFAULT_PROMPT_TEMPLATE - elif isinstance(prompt_template, str): - prompt_template = PromptTemplate(template=prompt_template) - - return ( - prompt_template.with_options(system_message=system_message).format(instruction=prompt), - {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k}, - {}, - ) diff --git a/openllm-core/src/openllm_core/prompts/__init__.py b/openllm-core/src/openllm_core/prompts/__init__.py deleted file mode 100644 index c9b7472a..00000000 --- a/openllm-core/src/openllm_core/prompts/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import annotations - -from .prompt_template import PromptTemplate as PromptTemplate, process_prompt as process_prompt diff --git a/openllm-core/src/openllm_core/prompts/prompt_template.py b/openllm-core/src/openllm_core/prompts/prompt_template.py deleted file mode 100644 index e9a31764..00000000 --- a/openllm-core/src/openllm_core/prompts/prompt_template.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations 
-import typing as t - -import attr - -from .utils import default_formatter - -# equivocal setattr to save one lookup per assignment -_object_setattr = object.__setattr__ - - -@attr.define(slots=True) -class PromptTemplate: - template: str - _input_variables: t.Sequence[str] = attr.field(init=False) - - def __attrs_post_init__(self) -> None: - self._input_variables = default_formatter.extract_template_variables(self.template) - - def with_options(self, **attrs: t.Any) -> PromptTemplate: - prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables} - o = attr.evolve(self, template=self.template.format(**prompt_variables)) - _object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template)) - return o - - def to_string(self) -> str: - return self.template - - def format(self, **attrs: t.Any) -> str: - prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables} - try: - return self.template.format(**prompt_variables) - except KeyError as e: - raise RuntimeError( - f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template." - ) from None - - -# TODO: remove process_prompt after refactor config for all models -def process_prompt( - prompt: str, template: PromptTemplate | str | None = None, use_prompt_template: bool = True, **attrs: t.Any -) -> str: - # Currently, all default prompt will always have `instruction` key. - if not use_prompt_template: - return prompt - elif template is None: - raise ValueError("'template' can't be None while 'use_prompt_template=False'") - if isinstance(template, PromptTemplate): - template = template.to_string() - template_variables = default_formatter.extract_template_variables(template) - prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} - if 'instruction' in prompt_variables: - raise RuntimeError( - "'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'" - ) - try: - return template.format(instruction=prompt, **prompt_variables) - except KeyError as e: - raise RuntimeError( - f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template." 
- ) from None diff --git a/openllm-core/src/openllm_core/prompts/utils.py b/openllm-core/src/openllm_core/prompts/utils.py deleted file mode 100644 index 423820ba..00000000 --- a/openllm-core/src/openllm_core/prompts/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations -import string -import typing as t - - -class PromptFormatter(string.Formatter): - """This PromptFormatter is largely based on langchain's implementation.""" - - def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: - if len(args) > 0: - raise ValueError('Positional arguments are not supported') - return super().vformat(format_string, args, kwargs) - - def check_unused_args( - self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] - ) -> None: - extras = set(kwargs).difference(used_args) - if extras: - raise KeyError(f'Extra params passed: {extras}') - - def extract_template_variables(self, template: str) -> t.Sequence[str]: - return [field[1] for field in self.parse(template) if field[1] is not None] - - -default_formatter = PromptFormatter() diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py index 2c7c6325..faa92acb 100644 --- a/openllm-core/src/openllm_core/utils/__init__.py +++ b/openllm-core/src/openllm_core/utils/__init__.py @@ -115,7 +115,7 @@ def calc_dir_size(path:PathType)->int:return sum(f.stat().st_size for f in _Path @functools.lru_cache(maxsize=128) def generate_hash_from_file(f:str,algorithm:t.Literal['md5','sha1']='sha1')->str:return str(getattr(hashlib,algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()) def getenv(env:str,default:t.Any=None,var:t.Sequence[str]|None=None)->t.Any: - env_key={f'OPENLLM_{env.upper()}',env.upper()} + env_key={env.upper(),f'OPENLLM_{env.upper()}'} if var is not None:env_key=set(var)|env_key def callback(k:str)->t.Any: _var = os.getenv(k) diff --git a/openllm-python/src/openllm/_deprecated.py b/openllm-python/src/openllm/_deprecated.py index 3743a09c..b3564568 100644 --- a/openllm-python/src/openllm/_deprecated.py +++ b/openllm-python/src/openllm/_deprecated.py @@ -77,8 +77,6 @@ def Runner( 'serialisation': first_not_none( attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation'] ), - 'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None), - 'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None), } ) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 90de744a..55edfd91 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -25,7 +25,6 @@ from openllm_core._typing_compat import ( TupleAny, ) from openllm_core.exceptions import MissingDependencyError -from openllm_core.prompts import PromptTemplate from openllm_core.utils import ( DEBUG, ENV_VARS_TRUE_VALUES, @@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin): _serialisation: LiteralSerialisation _local: bool _max_model_len: int | None - _prompt_template: PromptTemplate | None - _system_message: str | None __llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto' __llm_torch_dtype__: 'torch.dtype' = None @@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin): model_id, model_version=None, model_tag=None, - prompt_template=None, - system_message=None, llm_config=None, backend=None, *args, 
@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin): serialisation=serialisation, local=_local, max_model_len=max_model_len, - prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template, - system_message=system_message, LLM__model_attrs=model_attrs, LLM__tokenizer_attrs=tokenizer_attrs, llm_dtype__=dtype.lower(), @@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin): request_id = gen_random_uuid() if request_id is None else request_id previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n'] - async for out in self.runner.generate_iterator.async_stream( - prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True) - ): - generated = GenerationOutput.from_runner(out).with_options(prompt=prompt) - delta_outputs = [None] * len(generated.outputs) - if generated.finished: - break - for output in generated.outputs: - i = output.index - delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :] - previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids) - delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens) - yield generated.with_options(outputs=delta_outputs) + try: + generator = self.runner.generate_iterator.async_stream( + prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True) + ) + except Exception as err: + raise RuntimeError(f'Failed to start generation task: {err}') from err + + try: + async for out in generator: + generated = GenerationOutput.from_runner(out).with_options(prompt=prompt) + delta_outputs = [None] * len(generated.outputs) + if generated.finished: + break + for output in generated.outputs: + i = output.index + delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :] + previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids) + delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens) + yield generated.with_options(outputs=delta_outputs) + except Exception as err: + raise RuntimeError(f'Exception caught during generation: {err}') from err diff --git a/openllm-python/src/openllm/_llm.pyi b/openllm-python/src/openllm/_llm.pyi index 431410fe..e47a4294 100644 --- a/openllm-python/src/openllm/_llm.pyi +++ b/openllm-python/src/openllm/_llm.pyi @@ -18,7 +18,6 @@ from openllm_core._typing_compat import ( M, T, ) -from openllm_core.prompts import PromptTemplate from openllm_core.utils.representation import ReprArgs from ._quantisation import QuantizationConfig @@ -48,8 +47,6 @@ class LLM(Generic[M, T]): _adapter_map: Optional[AdapterMap] _serialisation: LiteralSerialisation _local: bool - _prompt_template: Optional[PromptTemplate] - _system_message: Optional[str] __llm_dtype__: Dtype = ... __llm_torch_dtype__: Optional[torch.dtype] = ... 
@@ -74,8 +71,6 @@ class LLM(Generic[M, T]): model_id: str, model_version: Optional[str] = ..., model_tag: Optional[Union[str, Tag]] = ..., - prompt_template: Optional[Union[str, PromptTemplate]] = ..., - system_message: Optional[str] = ..., llm_config: Optional[LLMConfig] = ..., backend: Optional[LiteralBackend] = ..., *args: Any, diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index c9de4ceb..a4293534 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None): return decorator(cls) -def runner(llm): +def runner(llm: openllm.LLM): from ._strategies import CascadingResourceStrategy try: @@ -35,19 +35,6 @@ def runner(llm): except bentoml.exceptions.NotFound as err: raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err - if llm._prompt_template: - prompt_template = llm._prompt_template.to_string() - elif hasattr(llm.config, 'default_prompt_template'): - prompt_template = llm.config.default_prompt_template - else: - prompt_template = None - if llm._system_message: - system_message = llm._system_message - elif hasattr(llm.config, 'default_system_message'): - system_message = llm.config.default_system_message - else: - system_message = None - return types.new_class( llm.config.__class__.__name__[:-6] + 'Runner', (bentoml.Runner,), @@ -80,8 +67,8 @@ def runner(llm): ('llm_tag', llm.tag), ), 'has_adapters': llm.has_adapters, - 'prompt_template': prompt_template, - 'system_message': system_message, + 'template': llm.config.template, + 'system_message': llm.config.system_message, } ), )( @@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable): def __init__(self, llm): if not is_ctranslate_available(): raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`') - self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer + self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): diff --git a/openllm-python/src/openllm/_runners.pyi b/openllm-python/src/openllm/_runners.pyi index 1621fd8b..681d8e57 100644 --- a/openllm-python/src/openllm/_runners.pyi +++ b/openllm-python/src/openllm/_runners.pyi @@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]): config: LLMConfig = ... backend: LiteralBackend = ... has_adapters: bool = ... - prompt_template: Optional[str] = ... - system_message: Optional[str] = ... + template: str = ... + system_message: str = ... 
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]): @staticmethod diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py index 72d60d8a..c1bb5273 100644 --- a/openllm-python/src/openllm/_service.py +++ b/openllm-python/src/openllm/_service.py @@ -13,8 +13,6 @@ logger = logging.getLogger(__name__) llm = openllm.LLM[t.Any, t.Any]( model_id=svars.model_id, model_tag=svars.model_tag, - prompt_template=svars.prompt_template, - system_message=svars.system_message, serialisation=svars.serialization, adapter_map=svars.adapter_map, trust_remote_code=svars.trust_remote_code, @@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput( backend=llm.__llm_backend__, model_id=llm.model_id, configuration=llm.config.model_dump_json().decode(), - prompt_template=llm.runner.prompt_template, - system_message=llm.runner.system_message, ) diff --git a/openllm-python/src/openllm/_service_vars.py b/openllm-python/src/openllm/_service_vars.py index 05900ff9..6cd5df40 100644 --- a/openllm-python/src/openllm/_service_vars.py +++ b/openllm-python/src/openllm/_service_vars.py @@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES model_id = os.environ['OPENLLM_MODEL_ID'] model_tag = None adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))) -prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE') -system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE') serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors') trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES diff --git a/openllm-python/src/openllm/_service_vars.pyi b/openllm-python/src/openllm/_service_vars.pyi index bf89944c..01859112 100644 --- a/openllm-python/src/openllm/_service_vars.pyi +++ b/openllm-python/src/openllm/_service_vars.pyi @@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation model_id: str = ... model_tag: Optional[str] = ... adapter_map: Optional[Dict[str, str]] = ... -prompt_template: Optional[str] = ... -system_message: Optional[str] = ... serialization: LiteralSerialisation = ... trust_remote_code: bool = ... 
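
With the OPENLLM_PROMPT_TEMPLATE / OPENLLM_SYSTEM_MESSAGE service variables gone, prompt scaffolding is resolved from the model's config instead of the environment. A minimal sketch of that flow, assuming a hypothetical render() helper; only LLM, config.template and config.system_message come from this diff:

# Sketch only: render() is illustrative, not part of this patch.
import openllm

llm = openllm.LLM('HuggingFaceH4/zephyr-7b-beta')  # any supported model id

def render(instruction: str) -> str:
  # config.template keeps only the {system_message} and {instruction} placeholders;
  # config.system_message supplies the model's default system prompt (may be '').
  return llm.config.template.format(system_message=llm.config.system_message, instruction=instruction)

prompt = render('Summarize this pull request.')
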
diff --git a/openllm-python/src/openllm/_service_vars_pkg.py b/openllm-python/src/openllm/_service_vars_pkg.py index 8b1b1d97..7cf5c203 100644 --- a/openllm-python/src/openllm/_service_vars_pkg.py +++ b/openllm-python/src/openllm/_service_vars_pkg.py @@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id model_tag = '{__model_tag__}' # openllm: model tag adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map serialization = '{__model_serialization__}' # openllm: model serialization -prompt_template = {__model_prompt_template__} # openllm: model prompt template -system_message = {__model_system_message__} # openllm: model system message trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py index 98fc2ced..81edce30 100644 --- a/openllm-python/src/openllm/bundle/_package.py +++ b/openllm-python/src/openllm/bundle/_package.py @@ -83,8 +83,6 @@ def construct_docker_options( None, llm._serialisation, llm, - llm._system_message, - llm._prompt_template, use_current_env=False, ) # XXX: We need to quote this so that the envvar in container recognize as valid json @@ -101,8 +99,6 @@ def construct_docker_options( OPENLLM_MODEL_ID = '# openllm: model id' OPENLLM_MODEL_TAG = '# openllm: model tag' OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map' -OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template' -OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message' OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization' OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code' @@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter): identifier = OPENLLM_MODEL_ADAPTER_MAP -class ModelPromptTemplateFormatter(_ServiceVarsFormatter): - keyword = '__model_prompt_template__' - identifier = OPENLLM_MODEL_PROMPT_TEMPLATE - - -class ModelSystemMessageFormatter(_ServiceVarsFormatter): - keyword = '__model_system_message__' - identifier = OPENLLM_MODEL_SYSTEM_MESSAGE - - class ModelSerializationFormatter(_ServiceVarsFormatter): keyword = '__model_serialization__' identifier = OPENLLM_MODEL_SERIALIZATION @@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map): src_contents[i] = serialization_formatter.parse_line(it) elif trust_remote_code_formatter.identifier in it: src_contents[i] = trust_remote_code_formatter.parse_line(it) - elif OPENLLM_MODEL_PROMPT_TEMPLATE in it: - if llm._prompt_template: - src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it) - else: - src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it) - elif OPENLLM_MODEL_SYSTEM_MESSAGE in it: - if llm._system_message: - src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it) - else: - src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it) script = f"# GENERATED BY 'openllm build {llm.model_id}'. 
DO NOT EDIT\n\n" + ''.join(src_contents) if SHOW_CODEGEN: diff --git a/openllm-python/src/openllm/entrypoints/cohere.py b/openllm-python/src/openllm/entrypoints/cohere.py index dc52bfa4..7197b54d 100644 --- a/openllm-python/src/openllm/entrypoints/cohere.py +++ b/openllm-python/src/openllm/entrypoints/cohere.py @@ -105,7 +105,7 @@ async def cohere_generate(req, llm): if err_check is not None: return err_check request_id = gen_random_uuid('cohere-generate') - config = llm.config.with_request(request) + config = llm.config.compatible_options(request) if request.prompt_vars is not None: prompt = request.prompt.format(**request.prompt_vars) @@ -202,7 +202,7 @@ async def cohere_chat(req, llm): _transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt'] ) logger.debug('Prompt: %r', prompt) - config = llm.config.with_request(request) + config = llm.config.compatible_options(request) try: result_generator = llm.generate_iterator(prompt, request_id=request_id, **config) diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py index 3bd1ffbc..835e78c6 100644 --- a/openllm-python/src/openllm/entrypoints/openai.py +++ b/openllm-python/src/openllm/entrypoints/openai.py @@ -156,7 +156,7 @@ async def chat_completions(req, llm): request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt'] ) logger.debug('Prompt: %r', prompt) - config = llm.config.with_request(request) + config = llm.config.compatible_options(request) try: result_generator = llm.generate_iterator(prompt, request_id=request_id, **config) @@ -272,14 +272,9 @@ async def completions(req, llm): prompt = request.prompt # TODO: Support multiple prompts - if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch - return error_response( - HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`." - ) - model_name, request_id = request.model, gen_random_uuid('cmpl') created_time = int(time.monotonic()) - config = llm.config.with_request(request) + config = llm.config.compatible_options(request) try: result_generator = llm.generate_iterator(prompt, request_id=request_id, **config) diff --git a/openllm-python/src/openllm_cli/_factory.py b/openllm-python/src/openllm_cli/_factory.py index b7bd68ed..3ebd8374 100644 --- a/openllm-python/src/openllm_cli/_factory.py +++ b/openllm-python/src/openllm_cli/_factory.py @@ -135,13 +135,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ... 
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]: def wrapper(fn: FC) -> t.Callable[[FC], FC]: composed = openllm.utils.compose( - _OpenLLM_GenericInternalConfig().to_click_options, + _OpenLLM_GenericInternalConfig.parse, _http_server_args if not serve_grpc else _grpc_server_args, cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'), dtype_option(factory=cog.optgroup), model_version_option(factory=cog.optgroup), - system_message_option(factory=cog.optgroup), - prompt_template_file_option(factory=cog.optgroup), cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup), @@ -318,27 +316,6 @@ def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal )(f) -def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_option( - '--system-message', - type=click.STRING, - default=None, - envvar='OPENLLM_SYSTEM_MESSAGE', - help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.', - **attrs, - )(f) - - -def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_option( - '--prompt-template-file', - type=click.File(), - default=None, - help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.', - **attrs, - )(f) - - def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--backend', diff --git a/openllm-python/src/openllm_cli/_sdk.py b/openllm-python/src/openllm_cli/_sdk.py index 555003cf..d868cbff 100644 --- a/openllm-python/src/openllm_cli/_sdk.py +++ b/openllm-python/src/openllm_cli/_sdk.py @@ -37,8 +37,6 @@ def _start( workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None, device: tuple[str, ...] | t.Literal['all'] | None = None, quantize: LiteralQuantise | None = None, - system_message: str | None = None, - prompt_template_file: str | None = None, adapter_map: dict[LiteralString, str | None] | None = None, backend: LiteralBackend | None = None, additional_args: list[str] | None = None, @@ -60,8 +58,6 @@ def _start( Args: model_id: The model id to start this LLMServer timeout: The server timeout - system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message. - prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.. workers_per_resource: Number of workers per resource assigned. See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy) for more information. By default, this is set to 1. 
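Because `--system-message` and `--prompt-template-file` no longer exist on the start decorator or the `_start`/`_build` SDK helpers, a custom system prompt is now supplied per request rather than baked into the server. A hedged sketch of that pattern against the OpenAI-compatible endpoint follows; the server address, port, and model name are assumptions for illustration.

```python
# Sketch: send the system message with each request; the server applies the
# tokenizer's chat template (see the entrypoints above) when building the prompt.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')  # assumed local OpenLLM server
resp = client.chat.completions.create(
  model='mistralai/Mistral-7B-Instruct-v0.1',  # assumed model id served by this instance
  messages=[
    {'role': 'system', 'content': 'You are a useful assistant'},
    {'role': 'user', 'content': 'Hello there'},
  ],
)
print(resp.choices[0].message.content)
```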
@@ -88,10 +84,6 @@ def _start( os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt') args: list[str] = [model_id] - if system_message: - args.extend(['--system-message', system_message]) - if prompt_template_file: - args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)]) if timeout: args.extend(['--server-timeout', str(timeout)]) if workers_per_resource: @@ -129,8 +121,6 @@ def _build( bento_version: str | None = None, quantize: LiteralQuantise | None = None, adapter_map: dict[str, str | None] | None = None, - system_message: str | None = None, - prompt_template_file: str | None = None, build_ctx: str | None = None, enable_features: tuple[str, ...] | None = None, dockerfile_template: str | None = None, @@ -155,8 +145,6 @@ def _build( model_id: The model id to build this BentoLLM model_version: Optional model version for this given LLM bento_version: Optional bento veresion for this given BentoLLM - system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message. - prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.. quantize: Quantize the model weights. This is only applicable for PyTorch models. Possible quantisation strategies: - int8: Quantize the model with 8bit (bitsandbytes required) @@ -209,10 +197,6 @@ def _build( args.extend([f'--enable-features={f}' for f in enable_features]) if overwrite: args.append('--overwrite') - if system_message: - args.extend(['--system-message', system_message]) - if prompt_template_file: - args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)]) if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) if model_version: diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 8cfadcd8..86b31185 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -100,11 +100,9 @@ from ._factory import ( model_name_argument, model_version_option, parse_config_options, - prompt_template_file_option, quantize_option, serialisation_option, start_decorator, - system_message_option, ) if t.TYPE_CHECKING: @@ -404,8 +402,6 @@ def start_command( model_id: str, server_timeout: int, model_version: str | None, - system_message: str | None, - prompt_template_file: t.IO[t.Any] | None, workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, @@ -437,7 +433,6 @@ def start_command( ) adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None) - prompt_template = prompt_template_file.read() if prompt_template_file is not None else None from openllm.serialisation.transformers.weights import has_safetensors_weights @@ -467,8 +462,6 @@ def start_command( llm = openllm.LLM[t.Any, t.Any]( model_id=model_id, model_version=model_version, - prompt_template=prompt_template, - system_message=system_message, backend=backend, adapter_map=adapter_map, quantize=quantize, @@ -495,8 +488,6 @@ def start_command( adapter_map, serialisation, llm, - system_message, - prompt_template, ) server = bentoml.HTTPServer('_service:svc', **server_attrs) @@ -541,8 +532,6 @@ def start_grpc_command( model_id: str, server_timeout: int, 
model_version: str | None, - system_message: str | None, - prompt_template_file: t.IO[t.Any] | None, workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, @@ -577,7 +566,6 @@ def start_grpc_command( ) adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None) - prompt_template = prompt_template_file.read() if prompt_template_file is not None else None from openllm.serialisation.transformers.weights import has_safetensors_weights @@ -604,8 +592,6 @@ def start_grpc_command( llm = openllm.LLM[t.Any, t.Any]( model_id=model_id, model_version=model_version, - prompt_template=prompt_template, - system_message=system_message, backend=backend, adapter_map=adapter_map, quantize=quantize, @@ -634,8 +620,6 @@ def start_grpc_command( adapter_map, serialisation, llm, - system_message, - prompt_template, ) server = bentoml.GrpcServer('_service:svc', **server_attrs) @@ -654,18 +638,7 @@ def start_grpc_command( def process_environ( - config, - server_timeout, - wpr, - device, - cors, - model_id, - adapter_map, - serialisation, - llm, - system_message, - prompt_template, - use_current_env=True, + config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True ) -> t.Dict[str, t.Any]: environ = parse_config_options( config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {} @@ -685,10 +658,6 @@ def process_environ( ) if llm.quantise: environ['QUANTIZE'] = str(llm.quantise) - if system_message: - environ['OPENLLM_SYSTEM_MESSAGE'] = system_message - if prompt_template: - environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template return environ @@ -929,8 +898,6 @@ class BuildBentoOutput(t.TypedDict): ) @dtype_option @backend_option -@system_message_option -@prompt_template_file_option @click.option( '--bento-version', type=str, @@ -1004,8 +971,6 @@ def build_command( adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None, - system_message: str | None, - prompt_template_file: t.IO[t.Any] | None, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool, @@ -1051,12 +1016,9 @@ def build_command( state = ItemState.NOT_FOUND - prompt_template = prompt_template_file.read() if prompt_template_file is not None else None llm = openllm.LLM[t.Any, t.Any]( model_id=model_id, model_version=model_version, - prompt_template=prompt_template, - system_message=system_message, backend=backend, quantize=quantize, dtype=dtype, @@ -1075,19 +1037,7 @@ def build_command( llm._tag = model.tag os.environ.update( - **process_environ( - llm.config, - llm.config['timeout'], - 1.0, - None, - True, - llm.model_id, - None, - llm._serialisation, - llm, - llm._system_message, - llm._prompt_template, - ) + **process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm) ) try: diff --git a/openllm-python/src/openllm_cli/extension/get_prompt.py b/openllm-python/src/openllm_cli/extension/get_prompt.py index 57e9f310..0e64c230 100644 --- a/openllm-python/src/openllm_cli/extension/get_prompt.py +++ b/openllm-python/src/openllm_cli/extension/get_prompt.py @@ -1,33 +1,82 @@ from __future__ import annotations import logging -import traceback +import string import typing as t +import attr import click import inflection import orjson from bentoml_cli.utils import opt_callback import openllm -import openllm_core from openllm_cli import termui from openllm_cli._factory import model_complete_envvar 
-from openllm_core.prompts import process_prompt logger = logging.getLogger(__name__) +class PromptFormatter(string.Formatter): +  def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: +    if len(args) > 0: +      raise ValueError('Positional arguments are not supported') +    return super().vformat(format_string, args, kwargs) + +  def check_unused_args( +    self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] +  ) -> None: +    extras = set(kwargs).difference(used_args) +    if extras: +      raise KeyError(f'Extra params passed: {extras}') + +  def extract_template_variables(self, template: str) -> t.Sequence[str]: +    return [field[1] for field in self.parse(template) if field[1] is not None] + + +default_formatter = PromptFormatter() + +# cache object.__setattr__ to save one attribute lookup per assignment +_object_setattr = object.__setattr__ + + +@attr.define(slots=True) +class PromptTemplate: +  template: str +  _input_variables: t.Sequence[str] = attr.field(init=False) + +  def __attrs_post_init__(self) -> None: +    self._input_variables = default_formatter.extract_template_variables(self.template) + +  def with_options(self, **attrs: t.Any) -> PromptTemplate: +    prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables} +    o = attr.evolve(self, template=self.template.format(**prompt_variables)) +    _object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template)) +    return o + +  def format(self, **attrs: t.Any) -> str: +    prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables} +    try: +      return self.template.format(**prompt_variables) +    except KeyError as e: +      raise RuntimeError( +        f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template." +      ) from None + + @click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS) -@click.argument( -  'model_name', -  type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), -  shell_complete=model_complete_envvar, -) +@click.argument('model_id', shell_complete=model_complete_envvar) @click.argument('prompt', type=click.STRING) -@click.option('--format', type=click.STRING, default=None) +@click.option('--prompt-template-file', type=click.File(), default=None) +@click.option('--chat-template-file', type=click.File(), default=None) +@click.option('--system-message', type=str, default=None) +@click.option( +  '--add-generation-prompt/--no-add-generation-prompt', +  default=False, +  help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This is only applicable if model-id is an HF model_id', +) @click.option( '--opt', -  help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)", +  help="Define additional prompt variables. 
(format: ``--opt system_message='You are a useful assistant'``)", required=False, multiple=True, callback=opt_callback, @@ -35,45 +84,96 @@ logger = logging.getLogger(__name__) ) @click.pass_context def cli( -  ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any +  ctx: click.Context, +  /, +  model_id: str, +  prompt: str, +  prompt_template_file: t.IO[t.Any] | None, +  chat_template_file: t.IO[t.Any] | None, +  system_message: str | None, +  add_generation_prompt: bool, +  _memoized: dict[str, t.Any], +  **_: t.Any, ) -> str | None: -  """Get the default prompt used by OpenLLM.""" -  module = getattr(openllm_core.config, f'configuration_{model_name}') +  """Helpers for generating prompts. + +  \b +  It accepts remote HF model_ids as well as model names passed to `openllm start`. + +  If you pass in an HF model_id, then it will use the tokenizer to generate the prompt. + +  ```bash +  openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" +  ``` + +  If you need to change the prompt template, you can create a template file containing the Jinja2 template and pass it through `--chat-template-file`. +  See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details. + +  \b +  ```bash +  openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2 +  ``` + +  \b + +  If you pass a model name, then it will use the OpenLLM configuration to generate the prompt. +  Note that this is mainly a utility; OpenLLM won't use these prompts to format requests for you. + +  \b +  ```bash +  openllm get-prompt mistral "Hello there" +  ``` +  """ _memoized = {k: v[0] for k, v in _memoized.items() if v} -  try: -    template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None) -    prompt_mapping = getattr(module, 'PROMPT_MAPPING', None) -    if template is None: -      raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None -    if callable(template): -      if format is None: -        if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None: -          raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.') -        raise click.BadOptionUsage( -          'format', -          f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})", -        ) -      if prompt_mapping is None: -        raise click.BadArgumentUsage( -          f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.' -        ) from None -      if format not in prompt_mapping: -        raise click.BadOptionUsage( -          'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})' -        ) -      _prompt_template = template(format) -    else: -      _prompt_template = template + +  if prompt_template_file and chat_template_file: +    ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.') + +  acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set( +    inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys() +  ) +  if model_id in acceptable: +    logger.warning( +      'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n' +    ) +    config = openllm.AutoConfig.for_model(model_id) +    template = prompt_template_file.read() if prompt_template_file is not None else config.template +    system_message = system_message or config.system_message + +    try: -      # backward-compatible. TO BE REMOVED once every model has default system message and prompt template. 
-      fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized) +      formatted = ( +        PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) +      ) except RuntimeError as err: logger.debug('Exception caught while formatting prompt: %s', err) -      fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters( -        prompt, prompt_template=_prompt_template -      )[0] -    termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white') -  except Exception as err: -    traceback.print_exc() -    raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err +      ctx.fail(str(err)) +  else: +    import transformers + +    trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False) +    config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) +    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code) +    if chat_template_file is not None: +      chat_template_file = chat_template_file.read() +    if system_message is None: +      logger.warning('system-message is not provided, using the default inferred from the model architecture.\n') +      for architecture in config.architectures: +        if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE(): +          system_message = ( +            openllm.AutoConfig.infer_class_from_name( +              openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture] +            ) +            .model_construct_env() +            .system_message +          ) +          break +      else: +        ctx.fail( +          f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message' +        ) +    messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}] +    formatted = tokenizer.apply_chat_template( +      messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False +    ) + +  termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white') ctx.exit(0) diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py index 851a5322..a2c28de8 100644 --- a/openllm-python/tests/configuration_test.py +++ b/openllm-python/tests/configuration_test.py @@ -171,7 +171,7 @@ def test_click_conversion(gen_settings: ModelSettings): return attrs cl_ = make_llm_config('ClickConversionLLM', gen_settings) -  wrapped = cl_.to_click_options(cli_mock) +  wrapped = cl_.parse(cli_mock) filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union} click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')] assert len(filtered) == len(click_options_filtered)
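For reference, the `PromptTemplate` that now lives in `extension/get_prompt.py` can also be exercised directly. The sketch below uses the class exactly as added in the diff above; the template string is made up, and the import path simply mirrors the new module location.

```python
# Exercise the PromptTemplate added to openllm_cli/extension/get_prompt.py.
from openllm_cli.extension.get_prompt import PromptTemplate  # path per this diff

tpl = PromptTemplate('{system_message}\n\nUSER: {instruction}\nASSISTANT:')  # illustrative template
bound = tpl.with_options(system_message='You are a useful assistant')  # bind one variable, keep the rest
print(bound.format(instruction='Hello there'))
# format() raises RuntimeError if a declared variable such as {instruction} is left out.
```

The CLI wraps the same flow: `openllm get-prompt mistral "Hello there"` renders a prompt from the OpenLLM configuration, while passing an HF model_id (optionally with `--chat-template-file`) goes through `tokenizer.apply_chat_template` instead.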