chore: cleanup unused prompt templates (#713)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-21 01:56:51 -05:00
committed by GitHub
parent e6b9a749a4
commit fde78a2c78
39 changed files with 300 additions and 923 deletions

View File

@@ -77,8 +77,6 @@ def Runner(
'serialisation': first_not_none(
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
),
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
}
)

View File

@@ -25,7 +25,6 @@ from openllm_core._typing_compat import (
TupleAny,
)
from openllm_core.exceptions import MissingDependencyError
from openllm_core.prompts import PromptTemplate
from openllm_core.utils import (
DEBUG,
ENV_VARS_TRUE_VALUES,
@@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin):
_serialisation: LiteralSerialisation
_local: bool
_max_model_len: int | None
_prompt_template: PromptTemplate | None
_system_message: str | None
__llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
__llm_torch_dtype__: 'torch.dtype' = None
@@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin):
model_id,
model_version=None,
model_tag=None,
prompt_template=None,
system_message=None,
llm_config=None,
backend=None,
*args,
@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin):
serialisation=serialisation,
local=_local,
max_model_len=max_model_len,
prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
system_message=system_message,
LLM__model_attrs=model_attrs,
LLM__tokenizer_attrs=tokenizer_attrs,
llm_dtype__=dtype.lower(),
@@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin):
request_id = gen_random_uuid() if request_id is None else request_id
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
async for out in self.runner.generate_iterator.async_stream(
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
):
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
delta_outputs = [None] * len(generated.outputs)
if generated.finished:
break
for output in generated.outputs:
i = output.index
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
yield generated.with_options(outputs=delta_outputs)
try:
generator = self.runner.generate_iterator.async_stream(
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
)
except Exception as err:
raise RuntimeError(f'Failed to start generation task: {err}') from err
try:
async for out in generator:
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
delta_outputs = [None] * len(generated.outputs)
if generated.finished:
break
for output in generated.outputs:
i = output.index
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
yield generated.with_options(outputs=delta_outputs)
except Exception as err:
raise RuntimeError(f'Exception caught during generation: {err}') from err

View File

@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
M,
T,
)
from openllm_core.prompts import PromptTemplate
from openllm_core.utils.representation import ReprArgs
from ._quantisation import QuantizationConfig
@@ -48,8 +47,6 @@ class LLM(Generic[M, T]):
_adapter_map: Optional[AdapterMap]
_serialisation: LiteralSerialisation
_local: bool
_prompt_template: Optional[PromptTemplate]
_system_message: Optional[str]
__llm_dtype__: Dtype = ...
__llm_torch_dtype__: Optional[torch.dtype] = ...
@@ -74,8 +71,6 @@ class LLM(Generic[M, T]):
model_id: str,
model_version: Optional[str] = ...,
model_tag: Optional[Union[str, Tag]] = ...,
prompt_template: Optional[Union[str, PromptTemplate]] = ...,
system_message: Optional[str] = ...,
llm_config: Optional[LLMConfig] = ...,
backend: Optional[LiteralBackend] = ...,
*args: Any,

View File

@@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None):
return decorator(cls)
def runner(llm):
def runner(llm: openllm.LLM):
from ._strategies import CascadingResourceStrategy
try:
@@ -35,19 +35,6 @@ def runner(llm):
except bentoml.exceptions.NotFound as err:
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err
if llm._prompt_template:
prompt_template = llm._prompt_template.to_string()
elif hasattr(llm.config, 'default_prompt_template'):
prompt_template = llm.config.default_prompt_template
else:
prompt_template = None
if llm._system_message:
system_message = llm._system_message
elif hasattr(llm.config, 'default_system_message'):
system_message = llm.config.default_system_message
else:
system_message = None
return types.new_class(
llm.config.__class__.__name__[:-6] + 'Runner',
(bentoml.Runner,),
@@ -80,8 +67,8 @@ def runner(llm):
('llm_tag', llm.tag),
),
'has_adapters': llm.has_adapters,
'prompt_template': prompt_template,
'system_message': system_message,
'template': llm.config.template,
'system_message': llm.config.system_message,
}
),
)(
@@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable):
def __init__(self, llm):
if not is_ctranslate_available():
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):

View File

@@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]):
config: LLMConfig = ...
backend: LiteralBackend = ...
has_adapters: bool = ...
prompt_template: Optional[str] = ...
system_message: Optional[str] = ...
template: str = ...
system_message: str = ...
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
@staticmethod

View File

@@ -13,8 +13,6 @@ logger = logging.getLogger(__name__)
llm = openllm.LLM[t.Any, t.Any](
model_id=svars.model_id,
model_tag=svars.model_tag,
prompt_template=svars.prompt_template,
system_message=svars.system_message,
serialisation=svars.serialization,
adapter_map=svars.adapter_map,
trust_remote_code=svars.trust_remote_code,
@@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput(
backend=llm.__llm_backend__,
model_id=llm.model_id,
configuration=llm.config.model_dump_json().decode(),
prompt_template=llm.runner.prompt_template,
system_message=llm.runner.system_message,
)

View File

@@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES
model_id = os.environ['OPENLLM_MODEL_ID']
model_tag = None
adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES

View File

@@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation
model_id: str = ...
model_tag: Optional[str] = ...
adapter_map: Optional[Dict[str, str]] = ...
prompt_template: Optional[str] = ...
system_message: Optional[str] = ...
serialization: LiteralSerialisation = ...
trust_remote_code: bool = ...

View File

@@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id
model_tag = '{__model_tag__}' # openllm: model tag
adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
serialization = '{__model_serialization__}' # openllm: model serialization
prompt_template = {__model_prompt_template__} # openllm: model prompt template
system_message = {__model_system_message__} # openllm: model system message
trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code

View File

@@ -83,8 +83,6 @@ def construct_docker_options(
None,
llm._serialisation,
llm,
llm._system_message,
llm._prompt_template,
use_current_env=False,
)
# XXX: We need to quote this so that the envvar in container recognize as valid json
@@ -101,8 +99,6 @@ def construct_docker_options(
OPENLLM_MODEL_ID = '# openllm: model id'
OPENLLM_MODEL_TAG = '# openllm: model tag'
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'
@@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
identifier = OPENLLM_MODEL_ADAPTER_MAP
class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
keyword = '__model_prompt_template__'
identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
class ModelSystemMessageFormatter(_ServiceVarsFormatter):
keyword = '__model_system_message__'
identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
class ModelSerializationFormatter(_ServiceVarsFormatter):
keyword = '__model_serialization__'
identifier = OPENLLM_MODEL_SERIALIZATION
@@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map):
src_contents[i] = serialization_formatter.parse_line(it)
elif trust_remote_code_formatter.identifier in it:
src_contents[i] = trust_remote_code_formatter.parse_line(it)
elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
if llm._prompt_template:
src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
else:
src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
if llm._system_message:
src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
else:
src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)
script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
if SHOW_CODEGEN:

View File

@@ -105,7 +105,7 @@ async def cohere_generate(req, llm):
if err_check is not None:
return err_check
request_id = gen_random_uuid('cohere-generate')
config = llm.config.with_request(request)
config = llm.config.compatible_options(request)
if request.prompt_vars is not None:
prompt = request.prompt.format(**request.prompt_vars)
@@ -202,7 +202,7 @@ async def cohere_chat(req, llm):
_transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
config = llm.config.with_request(request)
config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)

View File

@@ -156,7 +156,7 @@ async def chat_completions(req, llm):
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
config = llm.config.with_request(request)
config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -272,14 +272,9 @@ async def completions(req, llm):
prompt = request.prompt
# TODO: Support multiple prompts
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
return error_response(
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
)
model_name, request_id = request.model, gen_random_uuid('cmpl')
created_time = int(time.monotonic())
config = llm.config.with_request(request)
config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)

View File

@@ -135,13 +135,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
composed = openllm.utils.compose(
_OpenLLM_GenericInternalConfig().to_click_options,
_OpenLLM_GenericInternalConfig.parse,
_http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
dtype_option(factory=cog.optgroup),
model_version_option(factory=cog.optgroup),
system_message_option(factory=cog.optgroup),
prompt_template_file_option(factory=cog.optgroup),
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
workers_per_resource_option(factory=cog.optgroup),
cors_option(factory=cog.optgroup),
@@ -318,27 +316,6 @@ def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
)(f)
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--system-message',
type=click.STRING,
default=None,
envvar='OPENLLM_SYSTEM_MESSAGE',
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
**attrs,
)(f)
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--prompt-template-file',
type=click.File(),
default=None,
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
**attrs,
)(f)
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--backend',

View File

@@ -37,8 +37,6 @@ def _start(
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
device: tuple[str, ...] | t.Literal['all'] | None = None,
quantize: LiteralQuantise | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
adapter_map: dict[LiteralString, str | None] | None = None,
backend: LiteralBackend | None = None,
additional_args: list[str] | None = None,
@@ -60,8 +58,6 @@ def _start(
Args:
model_id: The model id to start this LLMServer
timeout: The server timeout
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
workers_per_resource: Number of workers per resource assigned.
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
for more information. By default, this is set to 1.
@@ -88,10 +84,6 @@ def _start(
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = [model_id]
if system_message:
args.extend(['--system-message', system_message])
if prompt_template_file:
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if timeout:
args.extend(['--server-timeout', str(timeout)])
if workers_per_resource:
@@ -129,8 +121,6 @@ def _build(
bento_version: str | None = None,
quantize: LiteralQuantise | None = None,
adapter_map: dict[str, str | None] | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
build_ctx: str | None = None,
enable_features: tuple[str, ...] | None = None,
dockerfile_template: str | None = None,
@@ -155,8 +145,6 @@ def _build(
model_id: The model id to build this BentoLLM
model_version: Optional model version for this given LLM
bento_version: Optional bento version for this given BentoLLM
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
@@ -209,10 +197,6 @@ def _build(
args.extend([f'--enable-features={f}' for f in enable_features])
if overwrite:
args.append('--overwrite')
if system_message:
args.extend(['--system-message', system_message])
if prompt_template_file:
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if adapter_map:
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version:

View File

@@ -100,11 +100,9 @@ from ._factory import (
model_name_argument,
model_version_option,
parse_config_options,
prompt_template_file_option,
quantize_option,
serialisation_option,
start_decorator,
system_message_option,
)
if t.TYPE_CHECKING:
@@ -404,8 +402,6 @@ def start_command(
model_id: str,
server_timeout: int,
model_version: str | None,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
device: t.Tuple[str, ...],
quantize: LiteralQuantise | None,
@@ -437,7 +433,6 @@ def start_command(
)
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
from openllm.serialisation.transformers.weights import has_safetensors_weights
@@ -467,8 +462,6 @@ def start_command(
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
adapter_map=adapter_map,
quantize=quantize,
@@ -495,8 +488,6 @@ def start_command(
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
)
server = bentoml.HTTPServer('_service:svc', **server_attrs)
@@ -541,8 +532,6 @@ def start_grpc_command(
model_id: str,
server_timeout: int,
model_version: str | None,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
device: t.Tuple[str, ...],
quantize: LiteralQuantise | None,
@@ -577,7 +566,6 @@ def start_grpc_command(
)
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
from openllm.serialisation.transformers.weights import has_safetensors_weights
@@ -604,8 +592,6 @@ def start_grpc_command(
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
adapter_map=adapter_map,
quantize=quantize,
@@ -634,8 +620,6 @@ def start_grpc_command(
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
)
server = bentoml.GrpcServer('_service:svc', **server_attrs)
@@ -654,18 +638,7 @@ def start_grpc_command(
def process_environ(
config,
server_timeout,
wpr,
device,
cors,
model_id,
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
use_current_env=True,
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True
) -> t.Dict[str, t.Any]:
environ = parse_config_options(
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
@@ -685,10 +658,6 @@ def process_environ(
)
if llm.quantise:
environ['QUANTIZE'] = str(llm.quantise)
if system_message:
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
return environ
@@ -929,8 +898,6 @@ class BuildBentoOutput(t.TypedDict):
)
@dtype_option
@backend_option
@system_message_option
@prompt_template_file_option
@click.option(
'--bento-version',
type=str,
@@ -1004,8 +971,6 @@ def build_command(
adapter_id: tuple[str, ...],
build_ctx: str | None,
backend: LiteralBackend | None,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
model_version: str | None,
dockerfile_template: t.TextIO | None,
containerize: bool,
@@ -1051,12 +1016,9 @@ def build_command(
state = ItemState.NOT_FOUND
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
quantize=quantize,
dtype=dtype,
@@ -1075,19 +1037,7 @@ def build_command(
llm._tag = model.tag
os.environ.update(
**process_environ(
llm.config,
llm.config['timeout'],
1.0,
None,
True,
llm.model_id,
None,
llm._serialisation,
llm,
llm._system_message,
llm._prompt_template,
)
**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
)
try:

View File

@@ -1,33 +1,82 @@
from __future__ import annotations
import logging
import traceback
import string
import typing as t
import attr
import click
import inflection
import orjson
from bentoml_cli.utils import opt_callback
import openllm
import openllm_core
from openllm_cli import termui
from openllm_cli._factory import model_complete_envvar
from openllm_core.prompts import process_prompt
logger = logging.getLogger(__name__)
class PromptFormatter(string.Formatter):
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
if len(args) > 0:
raise ValueError('Positional arguments are not supported')
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f'Extra params passed: {extras}')
def extract_template_variables(self, template: str) -> t.Sequence[str]:
return [field[1] for field in self.parse(template) if field[1] is not None]
default_formatter = PromptFormatter()
# equivocal setattr to save one lookup per assignment
_object_setattr = object.__setattr__
@attr.define(slots=True)
class PromptTemplate:
template: str
_input_variables: t.Sequence[str] = attr.field(init=False)
def __attrs_post_init__(self) -> None:
self._input_variables = default_formatter.extract_template_variables(self.template)
def with_options(self, **attrs: t.Any) -> PromptTemplate:
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
o = attr.evolve(self, template=self.template.format(**prompt_variables))
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
return o
def format(self, **attrs: t.Any) -> str:
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
try:
return self.template.format(**prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
) from None
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@click.argument(
'model_name',
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
shell_complete=model_complete_envvar,
)
@click.argument('model_id', shell_complete=model_complete_envvar)
@click.argument('prompt', type=click.STRING)
@click.option('--format', type=click.STRING, default=None)
@click.option('--prompt-template-file', type=click.File(), default=None)
@click.option('--chat-template-file', type=click.File(), default=None)
@click.option('--system-message', type=str, default=None)
@click.option(
'--add-generation-prompt/--no-add-generation-prompt',
default=False,
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
)
@click.option(
'--opt',
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
required=False,
multiple=True,
callback=opt_callback,
@@ -35,45 +84,96 @@ logger = logging.getLogger(__name__)
)
@click.pass_context
def cli(
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
ctx: click.Context,
/,
model_id: str,
prompt: str,
prompt_template_file: t.IO[t.Any] | None,
chat_template_file: t.IO[t.Any] | None,
system_message: str | None,
add_generation_prompt: bool,
_memoized: dict[str, t.Any],
**_: t.Any,
) -> str | None:
"""Get the default prompt used by OpenLLM."""
module = getattr(openllm_core.config, f'configuration_{model_name}')
"""Helpers for generating prompts.
\b
It accepts remote HF model_ids as well as model name passed to `openllm start`.
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
```bash
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
```
If you need to change the prompt template, you can create a template file that contains the jinja2 template and pass it through `--chat-template-file`
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
\b
```bash
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
```
\b
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
\b
```bash
openllm get-prompt mistral "Hello there"
```
"""
_memoized = {k: v[0] for k, v in _memoized.items() if v}
try:
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
if template is None:
raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
if callable(template):
if format is None:
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
raise click.BadOptionUsage(
'format',
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
)
if prompt_mapping is None:
raise click.BadArgumentUsage(
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
) from None
if format not in prompt_mapping:
raise click.BadOptionUsage(
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
)
_prompt_template = template(format)
else:
_prompt_template = template
if prompt_template_file and chat_template_file:
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
)
if model_id in acceptable:
logger.warning(
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
)
config = openllm.AutoConfig.for_model(model_id)
template = prompt_template_file.read() if prompt_template_file is not None else config.template
system_message = system_message or config.system_message
try:
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
formatted = (
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
)
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
prompt, prompt_template=_prompt_template
)[0]
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
except Exception as err:
traceback.print_exc()
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err
ctx.fail(str(err))
else:
import transformers
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
if chat_template_file is not None:
chat_template_file = chat_template_file.read()
if system_message is None:
logger.warning('system-message is not provided, using default infer from the model architecture.\n')
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
system_message = (
openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
)
.model_construct_env()
.system_message
)
break
else:
ctx.fail(
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
)
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
formatted = tokenizer.apply_chat_template(
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
)
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

View File

@@ -171,7 +171,7 @@ def test_click_conversion(gen_settings: ModelSettings):
return attrs
cl_ = make_llm_config('ClickConversionLLM', gen_settings)
wrapped = cl_.to_click_options(cli_mock)
wrapped = cl_.parse(cli_mock)
filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')]
assert len(filtered) == len(click_options_filtered)