mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 10:29:36 -04:00
chore: cleanup unused prompt templates (#713)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,33 +1,82 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import traceback
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
import openllm_core
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar,
|
||||
)
|
||||
@click.argument('model_id', shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@click.option('--prompt-template-file', type=click.File(), default=None)
|
||||
@click.option('--chat-template-file', type=click.File(), default=None)
|
||||
@click.option('--system-message', type=str, default=None)
|
||||
@click.option(
|
||||
'--add-generation-prompt/--no-add-generation-prompt',
|
||||
default=False,
|
||||
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
|
||||
)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
@@ -35,45 +84,96 @@ logger = logging.getLogger(__name__)
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(
|
||||
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
chat_template_file: t.IO[t.Any] | None,
|
||||
system_message: str | None,
|
||||
add_generation_prompt: bool,
|
||||
_memoized: dict[str, t.Any],
|
||||
**_: t.Any,
|
||||
) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = getattr(openllm_core.config, f'configuration_{model_name}')
|
||||
"""Helpers for generating prompts.
|
||||
|
||||
\b
|
||||
It accepts remote HF model_ids as well as model name passed to `openllm start`.
|
||||
|
||||
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
|
||||
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
|
||||
```
|
||||
|
||||
If you need change the prompt template, you can create the template file that contains the jina2 template through `--chat-template-file`
|
||||
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
|
||||
```
|
||||
|
||||
\b
|
||||
|
||||
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
|
||||
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt mistral "Hello there"
|
||||
"""
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
|
||||
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
|
||||
if template is None:
|
||||
raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
|
||||
)
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
|
||||
) from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
|
||||
)
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
|
||||
prompt, prompt_template=_prompt_template
|
||||
)[0]
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err
|
||||
ctx.fail(str(err))
|
||||
else:
|
||||
import transformers
|
||||
|
||||
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
if chat_template_file is not None:
|
||||
chat_template_file = chat_template_file.read()
|
||||
if system_message is None:
|
||||
logger.warning('system-message is not provided, using default infer from the model architecture.\n')
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
Reference in New Issue
Block a user