feat: PromptTemplate and system prompt support (#407)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
MingLiangDai
2023-10-03 09:53:37 -04:00
committed by GitHub
parent 43576fc8bb
commit a0e0f81306
24 changed files with 227 additions and 63 deletions

View File

@@ -123,6 +123,8 @@ Available official model_id(s): [default: {llm_config['default_id']}]
server_timeout: int,
model_id: str | None,
model_version: str | None,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
device: t.Tuple[str, ...],
quantize: LiteralQuantise | None,
@@ -175,6 +177,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
start_env = os.environ.copy()
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
start_env.update({
'OPENLLM_MODEL': model,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
@@ -185,10 +188,14 @@ Available official model_id(s): [default: {llm_config['default_id']}]
})
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
model_id=start_env[env.model_id],
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
ensure_available=True,
adapter_map=adapter_map,
@@ -233,6 +240,8 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
model_id_option(factory=cog.optgroup),
model_version_option(factory=cog.optgroup),
system_message_option(factory=cog.optgroup),
prompt_template_file_option(factory=cog.optgroup),
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
workers_per_resource_option(factory=cog.optgroup),
cors_option(factory=cog.optgroup),
@@ -377,6 +386,21 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--system-message',
type=click.STRING,
default=None,
envvar='OPENLLM_SYSTEM_MESSAGE',
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
**attrs)(f)
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--prompt-template-file',
type=click.File(),
default=None,
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
**attrs)(f)
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
# XXX: remove the check for __args__ once we have ggml and mlc supports

View File

@@ -40,6 +40,8 @@ def _start(model_name: str,
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
device: tuple[str, ...] | t.Literal['all'] | None = None,
quantize: LiteralQuantise | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
adapter_map: dict[LiteralString, str | None] | None = None,
backend: LiteralBackend | None = None,
additional_args: list[str] | None = None,
@@ -61,6 +63,8 @@ def _start(model_name: str,
model_name: The model name to start this LLM
model_id: Optional model id for this given LLM
timeout: The server timeout
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
workers_per_resource: Number of workers per resource assigned.
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
for more information. By default, this is set to 1.
@@ -90,6 +94,8 @@ def _start(model_name: str,
args: list[str] = []
if model_id: args.extend(['--model-id', model_id])
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if timeout: args.extend(['--server-timeout', str(timeout)])
if workers_per_resource:
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
@@ -113,6 +119,8 @@ def _build(model_name: str,
bento_version: str | None = None,
quantize: LiteralQuantise | None = None,
adapter_map: dict[str, str | None] | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
build_ctx: str | None = None,
enable_features: tuple[str, ...] | None = None,
workers_per_resource: float | None = None,
@@ -137,6 +145,8 @@ def _build(model_name: str,
model_id: Optional model id for this given LLM
model_version: Optional model version for this given LLM
bento_version: Optional bento veresion for this given BentoLLM
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
@@ -181,6 +191,8 @@ def _build(model_name: str,
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
if overwrite: args.append('--overwrite')
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version: args.extend(['--model-version', model_version])
if bento_version: args.extend(['--bento-version', bento_version])

View File

@@ -101,9 +101,11 @@ from ._factory import model_id_option
from ._factory import model_name_argument
from ._factory import model_version_option
from ._factory import output_option
from ._factory import prompt_template_file_option
from ._factory import quantize_option
from ._factory import serialisation_option
from ._factory import start_command_factory
from ._factory import system_message_option
from ._factory import workers_per_resource_option
if t.TYPE_CHECKING:
@@ -426,6 +428,8 @@ def import_command(
@output_option
@machine_option
@backend_option
@system_message_option
@prompt_template_file_option
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@workers_per_resource_option(factory=click, build=True)
@@ -478,6 +482,8 @@ def build_command(
adapter_id: tuple[str, ...],
build_ctx: str | None,
backend: LiteralBackend,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
machine: bool,
model_version: str | None,
dockerfile_template: t.TextIO | None,
@@ -514,6 +520,7 @@ def build_command(
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
@@ -521,10 +528,10 @@ def build_command(
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = infer_auto_class(env['backend_value']).for_model(
model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs
)
llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})

View File

@@ -13,7 +13,7 @@ from openllm.cli import termui
from openllm.cli._factory import machine_option
from openllm.cli._factory import model_complete_envvar
from openllm.cli._factory import output_option
from openllm_core._prompt import process_prompt
from openllm_core.prompts import process_prompt
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
@@ -51,7 +51,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
_prompt_template = template(format)
else:
_prompt_template = template
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
try:
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
except RuntimeError:
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
if machine: return repr(fully_formatted)
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
elif output == 'json':