refactor: move Prompt object to client specific attributes

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-05-25 22:58:25 +00:00
committed by Aaron
parent 545515c01f
commit 20b3a0260f
3 changed files with 72 additions and 52 deletions

View File

@@ -16,7 +16,6 @@ Schema definition for OpenLLM. This can be use for client interaction.
"""
from __future__ import annotations
import string
import typing as t
import inflection
@@ -25,56 +24,6 @@ import pydantic
import openllm
class PromptFormatter(string.Formatter):
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
if len(args) > 0:
raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
"""Check if extra params is passed."""
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f"Extra params passed: {extras}")
def extract_template_variables(self, template: str) -> t.Sequence[str]:
"""Extract template variables from a template string."""
return [field[1] for field in self.parse(template) if field[1] is not None]
# TODO: Support jinja2 template, go template and possible other prompt template engine.
_default_formatter = PromptFormatter()
class PromptTemplate(pydantic.BaseModel):
template: str
input_variables: t.Sequence[str]
model_config = {"extra": "forbid"}
def to_str(self, **kwargs: str) -> str:
"""Generate a prompt from the template and input variables"""
if not kwargs:
raise ValueError("Keyword arguments are required")
if not all(k in kwargs for k in self.input_variables):
raise ValueError(f"Missing required input variables: {self.input_variables}")
return _default_formatter.format(self.template, **kwargs)
@classmethod
def from_template(cls, template: str) -> PromptTemplate:
input_variables = _default_formatter.extract_template_variables(template)
return cls(template=template, input_variables=input_variables)
@classmethod
def from_default(cls, model: str) -> PromptTemplate:
template = getattr(openllm.utils.get_lazy_module(model), "DEFAULT_PROMPT_TEMPLATE")
if template is None:
raise ValueError(f"Model {model} does not have a default prompt template.")
return cls.from_template(template)
class GenerationInput(pydantic.BaseModel):
model_config = {"extra": "forbid"}

View File

@@ -23,7 +23,7 @@ from __future__ import annotations
import logging
import typing as t
import bentoml
from ._prompt import PromptTemplate as PromptTemplate
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,71 @@
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import string
import typing as t
import pydantic
import openllm
class PromptFormatter(string.Formatter):
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
if len(args) > 0:
raise ValueError("Positional arguments are not supported")
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
"""Check if extra params is passed."""
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f"Extra params passed: {extras}")
def extract_template_variables(self, template: str) -> t.Sequence[str]:
"""Extract template variables from a template string."""
return [field[1] for field in self.parse(template) if field[1] is not None]
# TODO: Support jinja2 template, go template and possible other prompt template engine.
_default_formatter = PromptFormatter()
class PromptTemplate(pydantic.BaseModel):
template: str
input_variables: t.Sequence[str]
model_config = {"extra": "forbid"}
def to_str(self, **kwargs: str) -> str:
"""Generate a prompt from the template and input variables"""
if not kwargs:
raise ValueError("Keyword arguments are required")
if not all(k in kwargs for k in self.input_variables):
raise ValueError(f"Missing required input variables: {self.input_variables}")
return _default_formatter.format(self.template, **kwargs)
@classmethod
def from_template(cls, template: str) -> PromptTemplate:
input_variables = _default_formatter.extract_template_variables(template)
return cls(template=template, input_variables=input_variables)
@classmethod
def from_default(cls, model: str) -> PromptTemplate:
template = getattr(openllm.utils.get_lazy_module(model), "DEFAULT_PROMPT_TEMPLATE")
if template is None:
raise ValueError(f"Model {model} does not have a default prompt template.")
return cls.from_template(template)