feat: PromptTemplate and system prompt support (#407)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Author: MingLiangDai
Date: 2023-10-03 09:53:37 -04:00 (committed by GitHub)
Parent: 43576fc8bb
Commit: a0e0f81306
24 changed files with 227 additions and 63 deletions


@@ -101,9 +101,11 @@ from ._factory import model_id_option
 from ._factory import model_name_argument
 from ._factory import model_version_option
 from ._factory import output_option
+from ._factory import prompt_template_file_option
 from ._factory import quantize_option
 from ._factory import serialisation_option
 from ._factory import start_command_factory
+from ._factory import system_message_option
 from ._factory import workers_per_resource_option
 if t.TYPE_CHECKING:
@@ -426,6 +428,8 @@ def import_command(
 @output_option
 @machine_option
 @backend_option
+@system_message_option
+@prompt_template_file_option
 @click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the model revision.')
 @click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
 @workers_per_resource_option(factory=click, build=True)
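
The two new decorators are the option factories imported in the first hunk. Their definitions in _factory.py are not part of this excerpt; the sketch below is an assumption that they are thin click.option wrappers, with hypothetical flag spellings and help strings:

    import click

    def system_message_option(f):
      # Hypothetical: optional system message, passed through as a plain string.
      return click.option('--system-message', type=str, default=None, help='Optional system message to bundle into the Bento.')(f)

    def prompt_template_file_option(f):
      # Hypothetical: click.File('r') opens the file, so the command receives a
      # t.IO[t.Any] handle, matching the build_command signature in the next hunk.
      return click.option('--prompt-template-file', type=click.File('r'), default=None, help='Optional file containing a prompt template.')(f)

Stacked on build_command, they only append the two parameters that the signature change in the next hunk picks up.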
@@ -478,6 +482,8 @@ def build_command(
 adapter_id: tuple[str, ...],
 build_ctx: str | None,
 backend: LiteralBackend,
+system_message: str | None,
+prompt_template_file: t.IO[t.Any] | None,
 machine: bool,
 model_version: str | None,
 dockerfile_template: t.TextIO | None,
@@ -514,6 +520,7 @@ def build_command(
 llm_config = AutoConfig.for_model(model_name)
 _serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
 env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
+prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
 # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
 # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
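
Since prompt_template_file arrives as an open handle (t.IO[t.Any], consistent with a click.File option), the ternary above reads the whole template verbatim. A small illustration, with a made-up file name and placeholders:

    # Hypothetical template.txt:
    #   {system_message}
    #
    #   User: {instruction}
    #   Assistant:
    # The read in the hunk above then amounts to:
    with open('template.txt') as prompt_template_file:
      prompt_template = prompt_template_file.read()  # raw text, placeholders left intact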
@@ -521,10 +528,10 @@ def build_command(
 os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
 if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
 if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
+if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
+if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
-llm = infer_auto_class(env['backend_value']).for_model(
-  model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs
-)
+llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
 labels = dict(llm.identifying_params)
 labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
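
The template and system message now travel two ways: as explicit prompt_template/system_message keyword arguments to for_model, and through the OPENLLM_PROMPT_TEMPLATE/OPENLLM_SYSTEM_MESSAGE environment variables so that service.py can recover them when bentoml build imports it (the workaround described in the NOTE above). A minimal sketch of that service-side lookup, assuming it reads the same variables:

    import os

    # Both come back as None when the corresponding flag was omitted at build time.
    system_message = os.environ.get('OPENLLM_SYSTEM_MESSAGE')
    prompt_template = os.environ.get('OPENLLM_PROMPT_TEMPLATE')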