mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-29 03:13:44 -04:00
feat: PromptTemplate and system prompt support (#407)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -123,6 +123,8 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
server_timeout: int,
|
||||
model_id: str | None,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -175,6 +177,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
|
||||
|
||||
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
start_env.update({
|
||||
'OPENLLM_MODEL': model,
|
||||
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
|
||||
@@ -185,10 +188,14 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
})
|
||||
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])
|
||||
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
|
||||
model_id=start_env[env.model_id],
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
ensure_available=True,
|
||||
adapter_map=adapter_map,
|
||||
@@ -233,6 +240,8 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
@@ -377,6 +386,21 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs)(f)
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs)(f)
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
|
||||
@@ -40,6 +40,8 @@ def _start(model_name: str,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
@@ -61,6 +63,8 @@ def _start(model_name: str,
|
||||
model_name: The model name to start this LLM
|
||||
model_id: Optional model id for this given LLM
|
||||
timeout: The server timeout
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -90,6 +94,8 @@ def _start(model_name: str,
|
||||
|
||||
args: list[str] = []
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
@@ -113,6 +119,8 @@ def _build(model_name: str,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
workers_per_resource: float | None = None,
|
||||
@@ -137,6 +145,8 @@ def _build(model_name: str,
|
||||
model_id: Optional model id for this given LLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento veresion for this given BentoLLM
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
@@ -181,6 +191,8 @@ def _build(model_name: str,
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
|
||||
@@ -101,9 +101,11 @@ from ._factory import model_id_option
|
||||
from ._factory import model_name_argument
|
||||
from ._factory import model_version_option
|
||||
from ._factory import output_option
|
||||
from ._factory import prompt_template_file_option
|
||||
from ._factory import quantize_option
|
||||
from ._factory import serialisation_option
|
||||
from ._factory import start_command_factory
|
||||
from ._factory import system_message_option
|
||||
from ._factory import workers_per_resource_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -426,6 +428,8 @@ def import_command(
|
||||
@output_option
|
||||
@machine_option
|
||||
@backend_option
|
||||
@system_message_option
|
||||
@prompt_template_file_option
|
||||
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
|
||||
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
||||
@workers_per_resource_option(factory=click, build=True)
|
||||
@@ -478,6 +482,8 @@ def build_command(
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
backend: LiteralBackend,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
machine: bool,
|
||||
model_version: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
@@ -514,6 +520,7 @@ def build_command(
|
||||
llm_config = AutoConfig.for_model(model_name)
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
|
||||
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
|
||||
@@ -521,10 +528,10 @@ def build_command(
|
||||
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
|
||||
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
|
||||
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = infer_auto_class(env['backend_value']).for_model(
|
||||
model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs
|
||||
)
|
||||
llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
|
||||
|
||||
@@ -13,7 +13,7 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import machine_option
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm.cli._factory import output_option
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
@@ -51,7 +51,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
try:
|
||||
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
except RuntimeError:
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
|
||||
elif output == 'json':
|
||||
|
||||
Reference in New Issue
Block a user