feat: PromptTemplate and system prompt support (#407)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
MingLiangDai
2023-10-03 09:53:37 -04:00
committed by GitHub
parent 43576fc8bb
commit a0e0f81306
24 changed files with 227 additions and 63 deletions

View File

@@ -39,6 +39,7 @@ _import_structure: dict[str, list[str]] = {
"bundle": [],
"playground": [],
"testing": [],
"prompts": ["PromptTemplate"],
"utils": ["infer_auto_class"],
"serialisation": ["ggml", "transformers"],
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
@@ -70,6 +71,7 @@ if _t.TYPE_CHECKING:
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
from .serialisation import ggml as ggml, transformers as transformers
from .prompts import PromptTemplate as PromptTemplate
from .utils import infer_auto_class as infer_auto_class
try:

View File

@@ -22,7 +22,6 @@ import openllm_core
from bentoml._internal.models.model import ModelSignature
from openllm_core._configuration import FineTuneConfig
from openllm_core._configuration import LLMConfig
from openllm_core._prompt import process_prompt
from openllm_core._schema import EmbeddingsOutput
from openllm_core._typing_compat import AdaptersMapping
from openllm_core._typing_compat import AdaptersTuple
@@ -40,6 +39,8 @@ from openllm_core._typing_compat import PeftAdapterOutput
from openllm_core._typing_compat import T
from openllm_core._typing_compat import TupleAny
from openllm_core._typing_compat import overload
from openllm_core.prompts import PromptTemplate
from openllm_core.prompts import process_prompt
from openllm_core.utils import DEBUG
from openllm_core.utils import MYPY
from openllm_core.utils import EnvVarMixin
@@ -293,6 +294,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
model_version: t.Optional[str],
serialisation: LiteralSerialisation,
_local: bool,
prompt_template: PromptTemplate | None,
system_message: str | None,
**attrs: t.Any) -> None:
'''Generated __attrs_init__ for openllm.LLM.'''
@@ -310,6 +313,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
_model_version: str
_serialisation: LiteralSerialisation
_local: bool
_prompt_template: PromptTemplate | None
_system_message: str | None
def __init_subclass__(cls: type[LLM[M, T]]) -> None:
cd = cls.__dict__
@@ -375,6 +380,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
def from_pretrained(cls,
model_id: str | None = None,
model_version: str | None = None,
prompt_template: PromptTemplate | str | None = None,
system_message: str | None = None,
llm_config: LLMConfig | None = None,
*args: t.Any,
quantize: LiteralQuantise | None = None,
@@ -421,6 +428,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
model_version: Optional version for this given model id. Default to None. This is useful for saving from custom path.
If set to None, the version will either be the git hash from given pretrained model, or the hash inferred
from last modified time of the given directory.
system_message: Optional system message for what the system prompt for the specified LLM is. If not given, the default system message will be used.
prompt_template: Optional custom prompt template. If not given, the default prompt template for the specified model will be used.
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
will use `config_class` to construct default configuration.
quantize: The quantization to use for this LLM. Defaults to None. Possible values
@@ -457,6 +466,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
if adapter_map is not None and not is_peft_available():
raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
if adapter_map: logger.debug('OpenLLM will apply the following adapters layers: %s', list(adapter_map))
if isinstance(prompt_template, str): prompt_template = PromptTemplate(prompt_template)
if llm_config is None:
llm_config = cls.config_class.model_construct_env(**attrs)
@@ -476,6 +486,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
quantization_config=quantization_config,
_quantize=quantize,
_model_version=_tag.version,
_prompt_template=prompt_template,
_system_message=system_message,
_tag=_tag,
_serialisation=serialisation,
_local=_local,
@@ -538,6 +550,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
_tag: bentoml.Tag,
_serialisation: LiteralSerialisation,
_local: bool,
_prompt_template: PromptTemplate | None,
_system_message: str | None,
_adapters_mapping: AdaptersMapping | None,
**attrs: t.Any,
):
@@ -650,7 +664,9 @@ class LLM(LLMInterface[M, T], ReprMixin):
_adapters_mapping,
_model_version,
_serialisation,
_local)
_local,
_prompt_template,
_system_message)
self.llm_post_init()
@@ -728,6 +744,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
- The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig
- The attributes dictionary that will be passed into `self.postprocess_generate`.
'''
attrs.update({'prompt_template': self._prompt_template, 'system_message': self._system_message})
return self.config.sanitize_parameters(prompt, **attrs)
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any:
@@ -937,9 +954,10 @@ class LLM(LLMInterface[M, T], ReprMixin):
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
def generate_one(self, prompt: str, stop: list[str], **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
prompt, generate_kwargs, _ = self.sanitize_parameters(prompt, **attrs)
max_new_tokens, encoded_inputs = generate_kwargs.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], generate_kwargs.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
@@ -949,6 +967,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]:
# TODO: support different generation strategies, similar to self.model.generate
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
for it in self.generate_iterator(prompt, **attrs):
pass
return [it]
@@ -968,6 +987,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
from ._generation import is_partial_stop
from ._generation import prepare_logits_processor
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
len_prompt = len(prompt)
config = self.config.model_construct_env(**attrs)
if stop_token_ids is None: stop_token_ids = []
@@ -1138,7 +1158,9 @@ def Runner(model_name: str,
attrs.update({
'model_id': llm_config['env']['model_id_value'],
'quantize': llm_config['env']['quantize_value'],
'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation'])
'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation']),
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
})
backend = t.cast(LiteralBackend, first_not_none(backend, default=EnvVarMixin(model_name, backend=llm_config.default_backend() if llm_config is not None else 'pt')['backend_value']))
@@ -1180,24 +1202,28 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate(prompt, **attrs)
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate(prompt, **attrs)
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
return self.generate_one(prompt, stop, **attrs)
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
pre = 0

View File

@@ -147,6 +147,8 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any],
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
}
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
if llm._system_message: env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
if llm._prompt_template: env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
# We need to handle None separately here, as env from subprocess doesn't accept None value.
_env = openllm_core.utils.EnvVarMixin(llm.config['model_name'], quantize=quantize)
@@ -212,10 +214,11 @@ def create_bento(bento_tag: bentoml.Tag,
container_version_strategy: LiteralContainerVersionStrategy = 'release',
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
_model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
backend_envvar = llm.config['env']['backend_value']
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation'])
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': backend_envvar, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
labels.update({
'_type': llm.llm_type, '_framework': llm.config['env']['backend_value'], 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'
})
if adapter_map: labels.update(adapter_map)
if isinstance(workers_per_resource, str):
if workers_per_resource == 'round_robin': workers_per_resource = 1.0

View File

@@ -123,6 +123,8 @@ Available official model_id(s): [default: {llm_config['default_id']}]
server_timeout: int,
model_id: str | None,
model_version: str | None,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
device: t.Tuple[str, ...],
quantize: LiteralQuantise | None,
@@ -175,6 +177,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
start_env = os.environ.copy()
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
start_env.update({
'OPENLLM_MODEL': model,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
@@ -185,10 +188,14 @@ Available official model_id(s): [default: {llm_config['default_id']}]
})
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
model_id=start_env[env.model_id],
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
ensure_available=True,
adapter_map=adapter_map,
@@ -233,6 +240,8 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
model_id_option(factory=cog.optgroup),
model_version_option(factory=cog.optgroup),
system_message_option(factory=cog.optgroup),
prompt_template_file_option(factory=cog.optgroup),
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
workers_per_resource_option(factory=cog.optgroup),
cors_option(factory=cog.optgroup),
@@ -377,6 +386,21 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--system-message',
type=click.STRING,
default=None,
envvar='OPENLLM_SYSTEM_MESSAGE',
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
**attrs)(f)
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--prompt-template-file',
type=click.File(),
default=None,
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
**attrs)(f)
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
# XXX: remove the check for __args__ once we have ggml and mlc supports

View File

@@ -40,6 +40,8 @@ def _start(model_name: str,
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
device: tuple[str, ...] | t.Literal['all'] | None = None,
quantize: LiteralQuantise | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
adapter_map: dict[LiteralString, str | None] | None = None,
backend: LiteralBackend | None = None,
additional_args: list[str] | None = None,
@@ -61,6 +63,8 @@ def _start(model_name: str,
model_name: The model name to start this LLM
model_id: Optional model id for this given LLM
timeout: The server timeout
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
      prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.
workers_per_resource: Number of workers per resource assigned.
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
for more information. By default, this is set to 1.
@@ -90,6 +94,8 @@ def _start(model_name: str,
args: list[str] = []
if model_id: args.extend(['--model-id', model_id])
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if timeout: args.extend(['--server-timeout', str(timeout)])
if workers_per_resource:
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
@@ -113,6 +119,8 @@ def _build(model_name: str,
bento_version: str | None = None,
quantize: LiteralQuantise | None = None,
adapter_map: dict[str, str | None] | None = None,
system_message: str | None = None,
prompt_template_file: str | None = None,
build_ctx: str | None = None,
enable_features: tuple[str, ...] | None = None,
workers_per_resource: float | None = None,
@@ -137,6 +145,8 @@ def _build(model_name: str,
model_id: Optional model id for this given LLM
model_version: Optional model version for this given LLM
      bento_version: Optional bento version for this given BentoLLM
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
      prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
@@ -181,6 +191,8 @@ def _build(model_name: str,
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
if overwrite: args.append('--overwrite')
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version: args.extend(['--model-version', model_version])
if bento_version: args.extend(['--bento-version', bento_version])

View File

@@ -101,9 +101,11 @@ from ._factory import model_id_option
from ._factory import model_name_argument
from ._factory import model_version_option
from ._factory import output_option
from ._factory import prompt_template_file_option
from ._factory import quantize_option
from ._factory import serialisation_option
from ._factory import start_command_factory
from ._factory import system_message_option
from ._factory import workers_per_resource_option
if t.TYPE_CHECKING:
@@ -426,6 +428,8 @@ def import_command(
@output_option
@machine_option
@backend_option
@system_message_option
@prompt_template_file_option
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@workers_per_resource_option(factory=click, build=True)
@@ -478,6 +482,8 @@ def build_command(
adapter_id: tuple[str, ...],
build_ctx: str | None,
backend: LiteralBackend,
system_message: str | None,
prompt_template_file: t.IO[t.Any] | None,
machine: bool,
model_version: str | None,
dockerfile_template: t.TextIO | None,
@@ -514,6 +520,7 @@ def build_command(
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
@@ -521,10 +528,10 @@ def build_command(
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = infer_auto_class(env['backend_value']).for_model(
model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs
)
llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})

View File

@@ -13,7 +13,7 @@ from openllm.cli import termui
from openllm.cli._factory import machine_option
from openllm.cli._factory import model_complete_envvar
from openllm.cli._factory import output_option
from openllm_core._prompt import process_prompt
from openllm_core.prompts import process_prompt
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
@@ -51,7 +51,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
_prompt_template = template(format)
else:
_prompt_template = template
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
try:
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
except RuntimeError:
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
if machine: return repr(fully_formatted)
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
elif output == 'json':

View File

@@ -0,0 +1,3 @@
from __future__ import annotations
from openllm_core.prompts import PromptTemplate as PromptTemplate