Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-01-22 14:31:26 -05:00
fix(perf): respect per request information
remove use_default_prompt_template options
add pretrained to list of start help docstring
fix flax generation config
improve flax and tensorflow implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -16,8 +16,6 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
@@ -71,7 +69,7 @@ all = ["openllm[fine-tune]", "openllm[flan-t5]", "openllm[chatglm]", "openllm[st
chatglm = ["cpm_kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]
fine-tune = ["peft", "bitsandbytes", "datasets", "accelerate"]
flan-t5 = ["flax", "jax", "jaxlib", "tensorflow"]
flan-t5 = ["flax", "jax", "jaxlib", "tensorflow", "keras"]
starcoder = ["bitsandbytes"]

[project.urls]
@@ -48,6 +48,7 @@ import inflection
import orjson
from cattr.gen import make_dict_unstructure_fn, override
from click_option_group import optgroup
from deepmerge.merger import Merger

import openllm

@@ -83,6 +84,15 @@ __all__ = ["LLMConfig"]

logger = logging.getLogger(__name__)

config_merger = Merger(
# merge dicts
type_strategies=[(DictStrAny, "merge")],
# override all other types
fallback_strategies=["override"],
# override conflicting types
type_conflict_strategies=["override"],
)


@t.overload
def attrs_to_options(
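For readers unfamiliar with deepmerge, a minimal sketch of how a Merger configured this way folds per-request overrides into a base configuration (DictStrAny in the diff is a typing alias; plain dict is used here for illustration):

from deepmerge.merger import Merger

config_merger = Merger(
    # dict values are merged recursively instead of being replaced wholesale
    type_strategies=[(dict, "merge")],
    # any other type falls back to "last one wins"
    fallback_strategies=["override"],
    # conflicting types also resolve to the incoming value
    type_conflict_strategies=["override"],
)

base = {"generation_config": {"temperature": 0.75, "max_new_tokens": 256}, "timeout": 30}
per_request = {"generation_config": {"temperature": 0.2}}

# merge() mutates and returns its first argument
config_merger.merge(base, per_request)
assert base["generation_config"] == {"temperature": 0.2, "max_new_tokens": 256}
assert base["timeout"] == 30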
@@ -593,13 +603,9 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf
return generated_cls


USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING = """\
Whether a model should use their default prompt template setup. This is useful if
users wants to do some prompt engineering. Default to True.
"""

# NOTE: This DEFAULT_KEYMAP is a way to dynamically generate attr.field
DEFAULT_LLMCONFIG_ATTRS = (("use_default_prompt_template", True, USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING, bool),)
# NOTE: This DEFAULT_LLMCONFIG_ATTRS is a way to dynamically generate attr.field
# and will be saved for future use in LLMConfig if we have some shared config.
DEFAULT_LLMCONFIG_ATTRS: tuple[tuple[str, t.Any, str, type[t.Any]], ...] = ()


@attr.define
@@ -652,6 +658,9 @@ class LLMConfig:
__openllm_url__: str = Field(None, init=False)
"""The resolved url for this LLMConfig."""

__openllm_accepted_keys__: set[str] = Field(None, init=False)
"""The accepted keys for this LLMConfig."""

__openllm_requirements__: list[str] | None = None
"""The default PyPI requirements needed to run this given LLM. By default, we will depend on
bentoml, torch, transformers."""
@@ -674,10 +683,6 @@ class LLMConfig:
"""The result generated GenerationConfig class for this LLMConfig. This will be used
to create the generation_config argument that can be used throughout the lifecycle."""

# NOTE: The following can be shared accross all LLMConfig subclasses.
use_default_prompt_template: bool = Field(True, init=False)
use_default_prompt_template.__doc__ = USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING

def __init_subclass__(
cls,
*,
@@ -757,11 +762,15 @@ class LLMConfig:
cls.__openllm_attrs__ = tuple(a.name for a in own_attrs)

# NOTE: Enable some default attributes that can be shared across all LLMConfig
base_attrs = [
attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints)
for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS
if k not in cls.__openllm_attrs__
] + base_attrs
if len(DEFAULT_LLMCONFIG_ATTRS) > 0:
# NOTE: update the hints for default variables we dynamically added.
hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS})
base_attrs = [
attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints)
for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS
if k not in cls.__openllm_attrs__
] + base_attrs

attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs

# Mandatory vs non-mandatory attr order only matters when they are part of
@@ -817,10 +826,10 @@ class LLMConfig:

hints.update(t.get_type_hints(cls.generation_class))

# NOTE: update the hints for default variables we dynamically added.
hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS})
cls.__openllm_hints__ = hints

cls.__openllm_accepted_keys__ = set(cls.__openllm_attrs__) | set(attr.fields_dict(cls.generation_class))

@property
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
return self.__openllm_name_type__
@@ -832,8 +841,10 @@ class LLMConfig:
__openllm_extras__: dict[str, t.Any] | None = None,
**attrs: t.Any,
):
to_exclude = list(attr.fields_dict(self.generation_class)) + list(self.__openllm_attrs__)
self.__openllm_extras__ = __openllm_extras__ or {k: v for k, v in attrs.items() if k not in to_exclude}
self.__openllm_extras__ = openllm.utils.first_not_none(__openllm_extras__, default={})
config_merger.merge(
self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}
)

attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_extras__ and v is not None}

@@ -844,9 +855,11 @@ class LLMConfig:

attrs = {k: v for k, v in attrs.items() if k not in generation_config}

extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
self.__attrs_init__(**{k: v for k, v in attrs.items() if k in self.__openllm_attrs__})

self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
# The rest update to extras
attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_attrs__}
config_merger.merge(self.__openllm_extras__, attrs)

def __repr__(self) -> str:
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
@@ -897,35 +910,35 @@ class LLMConfig:
return orjson.dumps(self.model_dump(**kwargs))

@classmethod
def model_construct_env(cls, __llm_config__: LLMConfig | None = None, **attrs: t.Any) -> LLMConfig:
def model_construct_env(cls, **attrs: t.Any) -> LLMConfig:
"""A helpers that respect configuration values that
sets from environment variables for any given configuration class.
"""
# NOTE: filter out None values
attrs = {k: v for k, v in attrs.items() if v is not None}
if "generation_config" in attrs:
# NOTE: We will need to flatten the attrs dict
generation_config = attrs.pop("generation_config", {})
attrs.update(generation_config)

env = ModelEnv(cls.__openllm_model_name__)
model_config = ModelEnv(cls.__openllm_model_name__).model_config

env_json_string = os.environ.get(env.model_config, None)
env_json_string = os.environ.get(model_config, None)

if env_json_string is not None:
try:
config_from_env = orjson.loads(env_json_string)
except orjson.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse '{env.model_config}' as valid JSON string.") from e
config_from_env.update(attrs)
return bentoml_cattr.structure(config_from_env, cls)
raise RuntimeError(f"Failed to parse '{model_config}' as valid JSON string.") from e
ncls = bentoml_cattr.structure(config_from_env, cls)
else:
ncls = cls()

if __llm_config__ is not None:
# NOTE: We only hit this branch on server-side, to ensure per-request configuration
# is respected.
attrs.update(__llm_config__.model_dump(flatten=True))
if "generation_config" in attrs:
generation_config = attrs.pop("generation_config")
if not LazyType(DictStrAny).isinstance(generation_config):
raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
else:
generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(ncls.generation_class)}

return bentoml_cattr.structure(attrs, cls)
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
ncls.generation_config = attr.evolve(ncls.generation_config, **generation_config)
return attr.evolve(ncls, **attrs)

def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, dict[str, t.Any]]:
"""Parse given click attributes into a LLMConfig and return the remaining click attributes."""
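A condensed sketch of the precedence the reworked model_construct_env implements: the model's environment variable provides the baseline, explicit keyword arguments win over it, and generation keys are split back out afterwards. The variable name and the field set below are illustrative, not the exact ones OpenLLM derives:

import os
import json

GENERATION_FIELDS = {"temperature", "top_k", "top_p", "max_new_tokens"}  # illustrative only

def construct_env(model_env_var: str, **attrs):
    # 1. drop unset values so they never override anything
    attrs = {k: v for k, v in attrs.items() if v is not None}
    # 2. flatten a nested generation_config dict into the same namespace
    attrs.update(attrs.pop("generation_config", {}))
    # 3. the environment variable (JSON) is the baseline configuration
    baseline = json.loads(os.environ.get(model_env_var, "{}"))
    # 4. request-level values win over the environment baseline
    merged = {**baseline, **attrs}
    # 5. split generation keys back out so they can evolve generation_config separately
    generation = {k: v for k, v in merged.items() if k in GENERATION_FIELDS}
    top_level = {k: v for k, v in merged.items() if k not in GENERATION_FIELDS}
    return top_level, generation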
@@ -1013,12 +1026,14 @@ def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMCon
raise RuntimeError(f"Expected a dictionary, but got {type(data)}")

cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_attrs__}
generation_cls_fields = attr.fields_dict(cls.generation_class)
if "generation_config" in data:
generation_config = data.pop("generation_config")
if not LazyType(DictStrAny).isinstance(generation_config):
raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}")
config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields})
else:
generation_config = {k: v for k, v in data.items() if k in attr.fields_dict(cls.generation_class)}
generation_config = {k: v for k, v in data.items() if k in generation_cls_fields}
not_extras = list(cls_attrs) + list(generation_config)
# The rest should be passed to extras
data = {k: v for k, v in data.items() if k not in not_extras}
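A rough sketch of the splitting this structuring hook performs, with placeholder field sets standing in for __openllm_attrs__ and the generation class fields:

def split_config(data: dict, cls_fields: set, generation_fields: set):
    # top-level config attributes stay on the LLMConfig itself
    cls_attrs = {k: v for k, v in data.items() if k in cls_fields}
    # an explicit generation_config dict is the base; flattened generation keys merge into it
    generation = dict(data.pop("generation_config", {}))
    generation.update({k: v for k, v in data.items() if k in generation_fields})
    # everything else is carried along as extras
    extras = {k: v for k, v in data.items() if k not in cls_attrs and k not in generation}
    return cls_attrs, generation, extras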
@@ -230,11 +230,7 @@ class LLMInterface(ABC):

It takes a prompt that is given by the user, attrs that can be parsed with the prompt.

NOTE: the attrs should also handle the following default attributes from all LLMConfig:
- use_default_prompt_template

Returns a tuple of three items:
- The processed prompt text depending on `use_default_prompt_template`
- The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig
- The attributes dictionary that will be passed into `self.postprocess_generate`.
"""
@@ -39,14 +39,14 @@ svc = bentoml.Service(name=f"llm-{llm_config.__openllm_start_name__}-service", r


@svc.api(
input=bentoml.io.JSON.from_sample(sample={"prompt": "", "llm_config": {}}),
output=bentoml.io.JSON.from_sample(sample={"responses": [], "configuration": {}}),
input=bentoml.io.JSON.from_sample(sample={"prompt": "", "llm_config": llm_config.model_dump()}),
output=bentoml.io.JSON.from_sample(sample={"responses": [], "configuration": llm_config.model_dump()}),
route="/v1/generate",
)
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
qa = openllm.GenerationInput.for_model(model)(**input_dict)
config = llm_config.model_construct_env(__llm_config__=qa.llm_config).model_dump()
responses = await runner.generate.async_run(qa.prompt, **config)
qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict)
config = qa_inputs.llm_config.model_dump()
responses = await runner.generate.async_run(qa_inputs.prompt, **config)
return openllm.GenerationOutput(responses=responses, configuration=config)
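With the per-request llm_config now carried through GenerationInput and used directly, a single request can override generation parameters without touching the server's defaults. A hedged usage sketch: the port assumes a default local `openllm start` deployment, and the nested payload shape assumes model_dump() nests a generation_config dict.

import requests

payload = {
    "prompt": "Explain what BentoML is in one sentence.",
    # per-request overrides; anything omitted falls back to the model's defaults
    "llm_config": {"generation_config": {"temperature": 0.2, "max_new_tokens": 64}},
}
resp = requests.post("http://localhost:3000/v1/generate", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
print(body["responses"][0])   # generated text
print(body["configuration"])  # the configuration actually used for this request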
@@ -314,15 +314,20 @@ def start_model_command(
configure_logging()

ModelEnv = openllm.utils.ModelEnv(model_name)
llm_config = openllm.AutoConfig.for_model(model_name)

docstring = f"""\
{ModelEnv.start_docstring}
\b
The available pretrained models to use with '{model_name}' are: {openllm.AutoLLM.for_model(model_name).pretrained}
"""
command_attrs: dict[str, t.Any] = {
"name": ModelEnv.model_name,
"context_settings": _context_settings or {},
"short_help": f"Start a LLMServer for '{model_name}' ('--help' for more details)",
"help": ModelEnv.start_docstring,
"help": docstring,
}

llm_config = openllm.AutoConfig.for_model(model_name)

aliases: list[str] = []
if llm_config.name_type == "dasherize":
aliases.append(llm_config.__openllm_start_name__)
@@ -110,7 +110,6 @@ class ChatGLM(openllm.LLM):
"num_beams": num_beams,
"top_p": top_p,
"temperature": temperature,
"use_default_prompt_template": use_default_prompt_template,
**attrs,
}
@@ -42,7 +42,6 @@ class DollyV2Config(
return_full_text: bool = openllm.LLMConfig.Field(
False, description="Whether to return the full prompt to the users."
)
use_default_prompt_template: bool = False

class GenerationConfig:
temperature: float = 0.9
@@ -33,6 +33,9 @@ logger = logging.getLogger(__name__)


class DollyV2(openllm.LLM):
if t.TYPE_CHECKING:
config: openllm.DollyV2Config

__openllm_internal__ = True

default_model = "databricks/dolly-v2-3b"
@@ -58,12 +61,20 @@ class DollyV2(openllm.LLM):
torch_dtype=torch_dtype,
device_map=device_map,
)
return bentoml.transformers.save_model(
tag,
pipeline,
custom_objects={"tokenizer": tokenizer},
external_modules=[importlib.import_module(pipeline.__module__)],
)
try:
return bentoml.transformers.save_model(
tag,
pipeline,
custom_objects={"tokenizer": tokenizer},
external_modules=[importlib.import_module(pipeline.__module__)],
)
finally:
import gc

gc.collect()

if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()

def sanitize_parameters(
self,
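The try/finally above releases the pipeline's memory once the model is saved. The same idea as a reusable helper (a sketch; whether empty_cache() actually helps depends on the runtime and is not part of the diff):

import contextlib
import gc

@contextlib.contextmanager
def release_memory_after():
    """Run the wrapped block, then aggressively release Python and CUDA memory."""
    try:
        yield
    finally:
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

# usage: the heavyweight pipeline is dropped from memory right after saving
# with release_memory_after():
#     bento_model = bentoml.transformers.save_model(tag, pipeline, custom_objects={"tokenizer": tokenizer})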
@@ -72,39 +83,37 @@ class DollyV2(openllm.LLM):
temperature: float | None = None,
top_k: int | None = None,
top_p: float | None = None,
use_default_prompt_template: bool = False,
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)
else:
prompt_text = prompt

# NOTE: The rest of attrs should be kwargs for GenerationConfig
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"use_default_prompt_template": use_default_prompt_template,
**attrs,
}

return prompt_text, generate_kwargs, {}
return prompt, generate_kwargs, {}

def postprocess_generate(self, prompt: str, generation_result: str, **_: t.Any) -> str:
return generation_result
def postprocess_generate(
self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any
) -> str:
return generation_result[0]["generated_text"]

@torch.inference_mode()
def generate(self, prompt: str, **attrs: t.Any) -> str:
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
self.model.tokenizer = self.tokenizer
llm_config: openllm.DollyV2Config = self.config.model_construct_env(**attrs)
decoded = self.model(prompt, generation_config=llm_config.to_generation_config())
llm_config = self.config.model_construct_env(**attrs)
decoded: list[dict[t.Literal["generated_text"], str]] = self.model(
prompt, generation_config=llm_config.to_generation_config()
)

# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if llm_config.return_full_text:
decoded = f"{DEFAULT_PROMPT_TEMPLATE.format(prompt)}\n{decoded}"
return [
{k: f"{DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)}\n{generated}"}
for i in decoded
for k, generated in i.items()
]

return decoded
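To illustrate the new return shape, a small example of the full-text prefixing when return_full_text=True (the template string here is shortened and is not Dolly's actual template):

PROMPT_TEMPLATE = "### Instruction:\n{instruction}\n### Response:\n"  # illustrative only

prompt = "Name three primary colors."
decoded = [{"generated_text": "Red, yellow and blue."}]

# with return_full_text=True the formatted instruction is prepended to every candidate
full = [
    {k: f"{PROMPT_TEMPLATE.format(instruction=prompt)}\n{generated}"}
    for item in decoded
    for k, generated in item.items()
]
# postprocess_generate then returns full[0]["generated_text"]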
@@ -32,8 +32,6 @@ class FalconConfig(
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
"""

use_default_prompt_template: bool = False

class GenerationConfig:
max_new_tokens: int = 200
top_k: int = 10
@@ -88,17 +88,20 @@ class Falcon(openllm.LLM):
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:
prompt_variables = {
k: v
for k, v in attrs.items()
if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
}
template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables:
raise RuntimeError(
"'instruction' should be passed as the first argument instead of "
"kwargs when 'use_default_prompt_template=True'"
)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
try:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
else:
prompt_text = prompt
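The guard and KeyError handling above depend on knowing which placeholders the template declares. A minimal stand-in for default_formatter.extract_template_variables, assuming it behaves like standard string.Formatter parsing:

from string import Formatter

def extract_template_variables(template: str) -> set:
    # collect the named fields ({instruction}, {context}, ...) that the template expects
    return {field for _, field, _, _ in Formatter().parse(template) if field}

TEMPLATE = "{context}\nQuestion: {instruction}\nAnswer:"  # illustrative template
assert extract_template_variables(TEMPLATE) == {"context", "instruction"}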
@@ -53,17 +53,20 @@ class FlanT5(openllm.LLM):
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:
prompt_variables = {
k: v
for k, v in attrs.items()
if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
}
template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables:
raise RuntimeError(
"'instruction' should be passed as the first argument "
"instead of kwargs when 'use_default_prompt_template=True'"
)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
try:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
else:
prompt_text = prompt
@@ -42,30 +42,38 @@ class FlaxFlanT5(openllm.LLM):
top_k: int | None = None,
top_p: float | None = None,
repetition_penalty: float | None = None,
decoder_start_token_id: int | None = None,
use_default_prompt_template: bool = True,
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:
prompt_variables = {
k: v
for k, v in attrs.items()
if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
}
template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
if "instruction" in prompt_variables:
raise RuntimeError(
"'instruction' should be passed as the first argument "
"instead of kwargs when 'use_default_prompt_template=True'"
)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
try:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
"Use 'use_default_prompt_template=False' to disable the default prompt template."
)
else:
prompt_text = prompt

if decoder_start_token_id is None:
decoder_start_token_id = 0

generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"decoder_start_token_id": decoder_start_token_id,
}
return prompt_text, generation_config, {}
@@ -73,11 +81,15 @@ class FlaxFlanT5(openllm.LLM):
return generation_result[0]

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
# XXX: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main
# as it is required for encoder-decoder generation.
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
input_ids = self.tokenizer(prompt, return_tensors="np")["input_ids"]
result_tensor = self.model.generate(
input_ids,
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
decoder_start_token_id=decoder_start_token_id,
)
return self.tokenizer.batch_decode(
result_tensor.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -65,6 +65,4 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM will refuse to participate in anything that could harm a human.
""" # noqa

DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
)
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
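With the eager .format(...) call removed, the template keeps both placeholders, so the system prompt is substituted per call and can be overridden per request. A small illustration with a shortened SYSTEM_PROMPT:

SYSTEM_PROMPT = "<|SYSTEM|># StableLM Tuned (Alpha version) ..."  # shortened for the example
DEFAULT_PROMPT_TEMPLATE = "{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"

# callers may override system_prompt per request, falling back to the default otherwise
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(
    system_prompt=SYSTEM_PROMPT,
    instruction="Write a haiku about GPUs.",
)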
@@ -21,7 +21,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList
import openllm

from ..._prompt import default_formatter
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT


class StopOnTokens(StoppingCriteria):
@@ -81,7 +81,8 @@ class StableLM(openllm.LLM):
"'instruction' should be passed as the first argument "
"instead of kwargs when 'use_default_prompt_template=True'"
)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
system_prompt = prompt_variables.pop("system_prompt", SYSTEM_PROMPT)
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, system_prompt=system_prompt)
else:
prompt_text = prompt
@@ -20,8 +20,6 @@ import bentoml

import openllm

from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE

if t.TYPE_CHECKING:
import torch
import transformers
@@ -103,7 +101,6 @@ class StarCoder(openllm.LLM):
top_p: float | None = None,
max_new_tokens: int | None = None,
repetition_penalty: float | None = None,
use_default_prompt_template: bool = True,
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
fim_mode = FIM_INDICATOR in prompt
@@ -129,9 +126,6 @@ class StarCoder(openllm.LLM):
**attrs,
}

if use_default_prompt_template:
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt_text)

return prompt_text, generation_config, {}

def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
@@ -34,6 +34,7 @@ from bentoml._internal.types import LazyType as LazyType
from bentoml._internal.utils import LazyLoader as LazyLoader
from bentoml._internal.utils import bentoml_cattr as bentoml_cattr
from bentoml._internal.utils import copy_file_to_fs_folder as copy_file_to_fs_folder
from bentoml._internal.utils import first_not_none as first_not_none
from bentoml._internal.utils import pkg as pkg
from bentoml._internal.utils import reserve_free_port as reserve_free_port
from bentoml._internal.utils import resolve_user_filepath as resolve_user_filepath
@@ -161,7 +161,11 @@ class BaseAsyncClient(ClientMixin):
...

async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
# NOTE: We set use_default_prompt_template to False for now.
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
)
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
res = await self.acall("generate", inputs)
r = openllm.GenerationOutput(**res)
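From the client side, the default prompt template is now opt-in rather than applied implicitly. A hedged usage sketch; the concrete client class name and server address are assumptions, not taken from the diff:

import asyncio
import openllm

async def main() -> None:
    # AsyncHTTPClient is an assumed concrete subclass of BaseAsyncClient
    client = openllm.client.AsyncHTTPClient("http://localhost:3000")
    # the template now defaults to off; opt back in explicitly if desired
    result = await client.query(
        "What is the capital of France?",
        use_default_prompt_template=False,
        temperature=0.1,
    )
    print(result)

asyncio.run(main())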