diff --git a/src/openllm/_schema.py b/src/openllm/_schema.py index f0bc5e7e..2934cc6a 100644 --- a/src/openllm/_schema.py +++ b/src/openllm/_schema.py @@ -16,7 +16,6 @@ Schema definition for OpenLLM. This can be use for client interaction. """ from __future__ import annotations -import string import typing as t import inflection @@ -25,56 +24,6 @@ import pydantic import openllm -class PromptFormatter(string.Formatter): - def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str: - if len(args) > 0: - raise ValueError("Positional arguments are not supported") - return super().vformat(format_string, args, kwargs) - - def check_unused_args( - self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] - ) -> None: - """Check if extra params is passed.""" - extras = set(kwargs).difference(used_args) - if extras: - raise KeyError(f"Extra params passed: {extras}") - - def extract_template_variables(self, template: str) -> t.Sequence[str]: - """Extract template variables from a template string.""" - return [field[1] for field in self.parse(template) if field[1] is not None] - - -# TODO: Support jinja2 template, go template and possible other prompt template engine. -_default_formatter = PromptFormatter() - - -class PromptTemplate(pydantic.BaseModel): - template: str - input_variables: t.Sequence[str] - - model_config = {"extra": "forbid"} - - def to_str(self, **kwargs: str) -> str: - """Generate a prompt from the template and input variables""" - if not kwargs: - raise ValueError("Keyword arguments are required") - if not all(k in kwargs for k in self.input_variables): - raise ValueError(f"Missing required input variables: {self.input_variables}") - return _default_formatter.format(self.template, **kwargs) - - @classmethod - def from_template(cls, template: str) -> PromptTemplate: - input_variables = _default_formatter.extract_template_variables(template) - return cls(template=template, input_variables=input_variables) - - @classmethod - def from_default(cls, model: str) -> PromptTemplate: - template = getattr(openllm.utils.get_lazy_module(model), "DEFAULT_PROMPT_TEMPLATE") - if template is None: - raise ValueError(f"Model {model} does not have a default prompt template.") - return cls.from_template(template) - - class GenerationInput(pydantic.BaseModel): model_config = {"extra": "forbid"} diff --git a/src/openllm_client/__init__.py b/src/openllm_client/__init__.py index 89359b7c..959dc5ed 100644 --- a/src/openllm_client/__init__.py +++ b/src/openllm_client/__init__.py @@ -23,7 +23,7 @@ from __future__ import annotations import logging import typing as t -import bentoml +from ._prompt import PromptTemplate as PromptTemplate logger = logging.getLogger(__name__) diff --git a/src/openllm_client/_prompt.py b/src/openllm_client/_prompt.py new file mode 100644 index 00000000..828893fa --- /dev/null +++ b/src/openllm_client/_prompt.py @@ -0,0 +1,71 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import string +import typing as t + +import pydantic + +import openllm + + +class PromptFormatter(string.Formatter): + def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str: + if len(args) > 0: + raise ValueError("Positional arguments are not supported") + return super().vformat(format_string, args, kwargs) + + def check_unused_args( + self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] + ) -> None: + """Check if extra params is passed.""" + extras = set(kwargs).difference(used_args) + if extras: + raise KeyError(f"Extra params passed: {extras}") + + def extract_template_variables(self, template: str) -> t.Sequence[str]: + """Extract template variables from a template string.""" + return [field[1] for field in self.parse(template) if field[1] is not None] + + +# TODO: Support jinja2 template, go template and possible other prompt template engine. +_default_formatter = PromptFormatter() + + +class PromptTemplate(pydantic.BaseModel): + template: str + input_variables: t.Sequence[str] + + model_config = {"extra": "forbid"} + + def to_str(self, **kwargs: str) -> str: + """Generate a prompt from the template and input variables""" + if not kwargs: + raise ValueError("Keyword arguments are required") + if not all(k in kwargs for k in self.input_variables): + raise ValueError(f"Missing required input variables: {self.input_variables}") + return _default_formatter.format(self.template, **kwargs) + + @classmethod + def from_template(cls, template: str) -> PromptTemplate: + input_variables = _default_formatter.extract_template_variables(template) + return cls(template=template, input_variables=input_variables) + + @classmethod + def from_default(cls, model: str) -> PromptTemplate: + template = getattr(openllm.utils.get_lazy_module(model), "DEFAULT_PROMPT_TEMPLATE") + if template is None: + raise ValueError(f"Model {model} does not have a default prompt template.") + return cls.from_template(template)