mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 15:46:16 -05:00
feat: PromptTemplate and system prompt support (#407)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,7 @@ _import_structure: dict[str, list[str]] = {
|
||||
"bundle": [],
|
||||
"playground": [],
|
||||
"testing": [],
|
||||
"prompts": ["PromptTemplate"],
|
||||
"utils": ["infer_auto_class"],
|
||||
"serialisation": ["ggml", "transformers"],
|
||||
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
@@ -70,6 +71,7 @@ if _t.TYPE_CHECKING:
|
||||
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
from .prompts import PromptTemplate as PromptTemplate
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
|
||||
try:
|
||||
|
||||
@@ -22,7 +22,6 @@ import openllm_core
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from openllm_core._configuration import FineTuneConfig
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core._schema import EmbeddingsOutput
|
||||
from openllm_core._typing_compat import AdaptersMapping
|
||||
from openllm_core._typing_compat import AdaptersTuple
|
||||
@@ -40,6 +39,8 @@ from openllm_core._typing_compat import PeftAdapterOutput
|
||||
from openllm_core._typing_compat import T
|
||||
from openllm_core._typing_compat import TupleAny
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import MYPY
|
||||
from openllm_core.utils import EnvVarMixin
|
||||
@@ -293,6 +294,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
model_version: t.Optional[str],
|
||||
serialisation: LiteralSerialisation,
|
||||
_local: bool,
|
||||
prompt_template: PromptTemplate | None,
|
||||
system_message: str | None,
|
||||
**attrs: t.Any) -> None:
|
||||
'''Generated __attrs_init__ for openllm.LLM.'''
|
||||
|
||||
@@ -310,6 +313,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
_model_version: str
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_prompt_template: PromptTemplate | None
|
||||
_system_message: str | None
|
||||
|
||||
def __init_subclass__(cls: type[LLM[M, T]]) -> None:
|
||||
cd = cls.__dict__
|
||||
@@ -375,6 +380,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
def from_pretrained(cls,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
*args: t.Any,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
@@ -421,6 +428,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
model_version: Optional version for this given model id. Default to None. This is useful for saving from custom path.
|
||||
If set to None, the version will either be the git hash from given pretrained model, or the hash inferred
|
||||
from last modified time of the given directory.
|
||||
system_message: Optional system message for what the system prompt for the specified LLM is. If not given, the default system message will be used.
|
||||
prompt_template: Optional custom prompt template. If not given, the default prompt template for the specified model will be used.
|
||||
llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM
|
||||
will use `config_class` to construct default configuration.
|
||||
quantize: The quantization to use for this LLM. Defaults to None. Possible values
|
||||
@@ -457,6 +466,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
if adapter_map is not None and not is_peft_available():
|
||||
raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if adapter_map: logger.debug('OpenLLM will apply the following adapters layers: %s', list(adapter_map))
|
||||
if isinstance(prompt_template, str): prompt_template = PromptTemplate(prompt_template)
|
||||
|
||||
if llm_config is None:
|
||||
llm_config = cls.config_class.model_construct_env(**attrs)
|
||||
@@ -476,6 +486,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
quantization_config=quantization_config,
|
||||
_quantize=quantize,
|
||||
_model_version=_tag.version,
|
||||
_prompt_template=prompt_template,
|
||||
_system_message=system_message,
|
||||
_tag=_tag,
|
||||
_serialisation=serialisation,
|
||||
_local=_local,
|
||||
@@ -538,6 +550,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
_tag: bentoml.Tag,
|
||||
_serialisation: LiteralSerialisation,
|
||||
_local: bool,
|
||||
_prompt_template: PromptTemplate | None,
|
||||
_system_message: str | None,
|
||||
_adapters_mapping: AdaptersMapping | None,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
@@ -650,7 +664,9 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
_adapters_mapping,
|
||||
_model_version,
|
||||
_serialisation,
|
||||
_local)
|
||||
_local,
|
||||
_prompt_template,
|
||||
_system_message)
|
||||
|
||||
self.llm_post_init()
|
||||
|
||||
@@ -728,6 +744,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
- The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig
|
||||
- The attributes dictionary that will be passed into `self.postprocess_generate`.
|
||||
'''
|
||||
attrs.update({'prompt_template': self._prompt_template, 'system_message': self._system_message})
|
||||
return self.config.sanitize_parameters(prompt, **attrs)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any:
|
||||
@@ -937,9 +954,10 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
|
||||
return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
def generate_one(self, prompt: str, stop: list[str], **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
prompt, generate_kwargs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
max_new_tokens, encoded_inputs = generate_kwargs.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], generate_kwargs.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
@@ -949,6 +967,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]:
|
||||
# TODO: support different generation strategies, similar to self.model.generate
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
for it in self.generate_iterator(prompt, **attrs):
|
||||
pass
|
||||
return [it]
|
||||
@@ -968,6 +987,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
from ._generation import is_partial_stop
|
||||
from ._generation import prepare_logits_processor
|
||||
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
len_prompt = len(prompt)
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
@@ -1138,7 +1158,9 @@ def Runner(model_name: str,
|
||||
attrs.update({
|
||||
'model_id': llm_config['env']['model_id_value'],
|
||||
'quantize': llm_config['env']['quantize_value'],
|
||||
'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation'])
|
||||
'serialisation': first_not_none(os.environ.get('OPENLLM_SERIALIZATION'), attrs.get('serialisation'), default=llm_config['serialisation']),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
})
|
||||
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, default=EnvVarMixin(model_name, backend=llm_config.default_backend() if llm_config is not None else 'pt')['backend_value']))
|
||||
@@ -1180,24 +1202,28 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
return self.generate(prompt, **attrs)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
return self.generate(prompt, **attrs)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def generate_one(__self: _Runnable, prompt: str, stop: list[str], **attrs: t.Any) -> t.Sequence[dict[t.Literal['generated_text'], str]]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
return self.generate_one(prompt, stop, **attrs)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
|
||||
def generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.Generator[str, None, str]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
pre = 0
|
||||
|
||||
@@ -147,6 +147,8 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any],
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
}
|
||||
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
if llm._system_message: env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
|
||||
if llm._prompt_template: env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config['model_name'], quantize=quantize)
|
||||
@@ -212,10 +214,11 @@ def create_bento(bento_tag: bentoml.Tag,
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
|
||||
backend_envvar = llm.config['env']['backend_value']
|
||||
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation'])
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': backend_envvar, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
|
||||
labels.update({
|
||||
'_type': llm.llm_type, '_framework': llm.config['env']['backend_value'], 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'
|
||||
})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == 'round_robin': workers_per_resource = 1.0
|
||||
|
||||
@@ -123,6 +123,8 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
server_timeout: int,
|
||||
model_id: str | None,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -175,6 +177,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
|
||||
|
||||
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
start_env.update({
|
||||
'OPENLLM_MODEL': model,
|
||||
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
|
||||
@@ -185,10 +188,14 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
})
|
||||
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])
|
||||
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
|
||||
model_id=start_env[env.model_id],
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
ensure_available=True,
|
||||
adapter_map=adapter_map,
|
||||
@@ -233,6 +240,8 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
@@ -377,6 +386,21 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs)(f)
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs)(f)
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
|
||||
@@ -40,6 +40,8 @@ def _start(model_name: str,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
@@ -61,6 +63,8 @@ def _start(model_name: str,
|
||||
model_name: The model name to start this LLM
|
||||
model_id: Optional model id for this given LLM
|
||||
timeout: The server timeout
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -90,6 +94,8 @@ def _start(model_name: str,
|
||||
|
||||
args: list[str] = []
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
@@ -113,6 +119,8 @@ def _build(model_name: str,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
workers_per_resource: float | None = None,
|
||||
@@ -137,6 +145,8 @@ def _build(model_name: str,
|
||||
model_id: Optional model id for this given LLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento veresion for this given BentoLLM
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
@@ -181,6 +191,8 @@ def _build(model_name: str,
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
|
||||
@@ -101,9 +101,11 @@ from ._factory import model_id_option
|
||||
from ._factory import model_name_argument
|
||||
from ._factory import model_version_option
|
||||
from ._factory import output_option
|
||||
from ._factory import prompt_template_file_option
|
||||
from ._factory import quantize_option
|
||||
from ._factory import serialisation_option
|
||||
from ._factory import start_command_factory
|
||||
from ._factory import system_message_option
|
||||
from ._factory import workers_per_resource_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -426,6 +428,8 @@ def import_command(
|
||||
@output_option
|
||||
@machine_option
|
||||
@backend_option
|
||||
@system_message_option
|
||||
@prompt_template_file_option
|
||||
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
|
||||
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
||||
@workers_per_resource_option(factory=click, build=True)
|
||||
@@ -478,6 +482,8 @@ def build_command(
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
backend: LiteralBackend,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
machine: bool,
|
||||
model_version: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
@@ -514,6 +520,7 @@ def build_command(
|
||||
llm_config = AutoConfig.for_model(model_name)
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
|
||||
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
|
||||
@@ -521,10 +528,10 @@ def build_command(
|
||||
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
|
||||
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
|
||||
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = infer_auto_class(env['backend_value']).for_model(
|
||||
model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs
|
||||
)
|
||||
llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
|
||||
|
||||
@@ -13,7 +13,7 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import machine_option
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm.cli._factory import output_option
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
@@ -51,7 +51,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
try:
|
||||
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
except RuntimeError:
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
|
||||
elif output == 'json':
|
||||
|
||||
3
openllm-python/src/openllm/prompts.py
Normal file
3
openllm-python/src/openllm/prompts.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm_core.prompts import PromptTemplate as PromptTemplate
|
||||
Reference in New Issue
Block a user