mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-04 15:16:03 -05:00
chore: cleanup unused prompt templates (#713)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -77,8 +77,6 @@ def Runner(
|
||||
'serialisation': first_not_none(
|
||||
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
|
||||
),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from openllm_core._typing_compat import (
|
||||
TupleAny,
|
||||
)
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
@@ -129,8 +128,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: int | None
|
||||
_prompt_template: PromptTemplate | None
|
||||
_system_message: str | None
|
||||
|
||||
__llm_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
|
||||
__llm_torch_dtype__: 'torch.dtype' = None
|
||||
@@ -148,8 +145,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
model_id,
|
||||
model_version=None,
|
||||
model_tag=None,
|
||||
prompt_template=None,
|
||||
system_message=None,
|
||||
llm_config=None,
|
||||
backend=None,
|
||||
*args,
|
||||
@@ -192,8 +187,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
max_model_len=max_model_len,
|
||||
prompt_template=PromptTemplate(prompt_template) if isinstance(prompt_template, str) else prompt_template,
|
||||
system_message=system_message,
|
||||
LLM__model_attrs=model_attrs,
|
||||
LLM__tokenizer_attrs=tokenizer_attrs,
|
||||
llm_dtype__=dtype.lower(),
|
||||
@@ -480,16 +473,24 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
|
||||
request_id = gen_random_uuid() if request_id is None else request_id
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
async for out in self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = [None] * len(generated.outputs)
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
try:
|
||||
generator = self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Failed to start generation task: {err}') from err
|
||||
|
||||
try:
|
||||
async for out in generator:
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = [None] * len(generated.outputs)
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Exception caught during generation: {err}') from err
|
||||
|
||||
@@ -18,7 +18,6 @@ from openllm_core._typing_compat import (
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.utils.representation import ReprArgs
|
||||
|
||||
from ._quantisation import QuantizationConfig
|
||||
@@ -48,8 +47,6 @@ class LLM(Generic[M, T]):
|
||||
_adapter_map: Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_prompt_template: Optional[PromptTemplate]
|
||||
_system_message: Optional[str]
|
||||
|
||||
__llm_dtype__: Dtype = ...
|
||||
__llm_torch_dtype__: Optional[torch.dtype] = ...
|
||||
@@ -74,8 +71,6 @@ class LLM(Generic[M, T]):
|
||||
model_id: str,
|
||||
model_version: Optional[str] = ...,
|
||||
model_tag: Optional[Union[str, Tag]] = ...,
|
||||
prompt_template: Optional[Union[str, PromptTemplate]] = ...,
|
||||
system_message: Optional[str] = ...,
|
||||
llm_config: Optional[LLMConfig] = ...,
|
||||
backend: Optional[LiteralBackend] = ...,
|
||||
*args: Any,
|
||||
|
||||
@@ -27,7 +27,7 @@ def registry(cls=None, *, alias=None):
|
||||
return decorator(cls)
|
||||
|
||||
|
||||
def runner(llm):
|
||||
def runner(llm: openllm.LLM):
|
||||
from ._strategies import CascadingResourceStrategy
|
||||
|
||||
try:
|
||||
@@ -35,19 +35,6 @@ def runner(llm):
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err
|
||||
|
||||
if llm._prompt_template:
|
||||
prompt_template = llm._prompt_template.to_string()
|
||||
elif hasattr(llm.config, 'default_prompt_template'):
|
||||
prompt_template = llm.config.default_prompt_template
|
||||
else:
|
||||
prompt_template = None
|
||||
if llm._system_message:
|
||||
system_message = llm._system_message
|
||||
elif hasattr(llm.config, 'default_system_message'):
|
||||
system_message = llm.config.default_system_message
|
||||
else:
|
||||
system_message = None
|
||||
|
||||
return types.new_class(
|
||||
llm.config.__class__.__name__[:-6] + 'Runner',
|
||||
(bentoml.Runner,),
|
||||
@@ -80,8 +67,8 @@ def runner(llm):
|
||||
('llm_tag', llm.tag),
|
||||
),
|
||||
'has_adapters': llm.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
'template': llm.config.template,
|
||||
'system_message': llm.config.system_message,
|
||||
}
|
||||
),
|
||||
)(
|
||||
@@ -101,7 +88,7 @@ class CTranslateRunnable(bentoml.Runnable):
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available():
|
||||
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.modle, llm.tokenizer
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
|
||||
@@ -83,8 +83,8 @@ class Runner(Protocol[Mo, To]):
|
||||
config: LLMConfig = ...
|
||||
backend: LiteralBackend = ...
|
||||
has_adapters: bool = ...
|
||||
prompt_template: Optional[str] = ...
|
||||
system_message: Optional[str] = ...
|
||||
template: str = ...
|
||||
system_message: str = ...
|
||||
|
||||
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
|
||||
@staticmethod
|
||||
|
||||
@@ -13,8 +13,6 @@ logger = logging.getLogger(__name__)
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=svars.prompt_template,
|
||||
system_message=svars.system_message,
|
||||
serialisation=svars.serialization,
|
||||
adapter_map=svars.adapter_map,
|
||||
trust_remote_code=svars.trust_remote_code,
|
||||
@@ -50,8 +48,6 @@ _Metadata = openllm.MetadataOutput(
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm.config.model_dump_json().decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,5 @@ from openllm_core.utils import ENV_VARS_TRUE_VALUES
|
||||
model_id = os.environ['OPENLLM_MODEL_ID']
|
||||
model_tag = None
|
||||
adapter_map = orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None)))
|
||||
prompt_template = os.getenv('OPENLLM_PROMPT_TEMPLATE')
|
||||
system_message = os.getenv('OPENLLM_SYSTEM_MESSAGE')
|
||||
serialization = os.getenv('OPENLLM_SERIALIZATION', default='safetensors')
|
||||
trust_remote_code = str(os.getenv('TRUST_REMOTE_CODE', default=str(False))).upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
@@ -5,7 +5,5 @@ from openllm_core._typing_compat import LiteralSerialisation
|
||||
model_id: str = ...
|
||||
model_tag: Optional[str] = ...
|
||||
adapter_map: Optional[Dict[str, str]] = ...
|
||||
prompt_template: Optional[str] = ...
|
||||
system_message: Optional[str] = ...
|
||||
serialization: LiteralSerialisation = ...
|
||||
trust_remote_code: bool = ...
|
||||
|
||||
@@ -4,6 +4,4 @@ model_id = '{__model_id__}' # openllm: model id
|
||||
model_tag = '{__model_tag__}' # openllm: model tag
|
||||
adapter_map = orjson.loads("""{__model_adapter_map__}""") # openllm: model adapter map
|
||||
serialization = '{__model_serialization__}' # openllm: model serialization
|
||||
prompt_template = {__model_prompt_template__} # openllm: model prompt template
|
||||
system_message = {__model_system_message__} # openllm: model system message
|
||||
trust_remote_code = {__model_trust_remote_code__} # openllm: model trust remote code
|
||||
|
||||
@@ -83,8 +83,6 @@ def construct_docker_options(
|
||||
None,
|
||||
llm._serialisation,
|
||||
llm,
|
||||
llm._system_message,
|
||||
llm._prompt_template,
|
||||
use_current_env=False,
|
||||
)
|
||||
# XXX: We need to quote this so that the envvar in container recognize as valid json
|
||||
@@ -101,8 +99,6 @@ def construct_docker_options(
|
||||
OPENLLM_MODEL_ID = '# openllm: model id'
|
||||
OPENLLM_MODEL_TAG = '# openllm: model tag'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
OPENLLM_MODEL_PROMPT_TEMPLATE = '# openllm: model prompt template'
|
||||
OPENLLM_MODEL_SYSTEM_MESSAGE = '# openllm: model system message'
|
||||
OPENLLM_MODEL_SERIALIZATION = '# openllm: model serialization'
|
||||
OPENLLM_MODEL_TRUST_REMOTE_CODE = '# openllm: model trust remote code'
|
||||
|
||||
@@ -140,16 +136,6 @@ class ModelAdapterMapFormatter(_ServiceVarsFormatter):
|
||||
identifier = OPENLLM_MODEL_ADAPTER_MAP
|
||||
|
||||
|
||||
class ModelPromptTemplateFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_prompt_template__'
|
||||
identifier = OPENLLM_MODEL_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class ModelSystemMessageFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_system_message__'
|
||||
identifier = OPENLLM_MODEL_SYSTEM_MESSAGE
|
||||
|
||||
|
||||
class ModelSerializationFormatter(_ServiceVarsFormatter):
|
||||
keyword = '__model_serialization__'
|
||||
identifier = OPENLLM_MODEL_SERIALIZATION
|
||||
@@ -187,16 +173,6 @@ def write_service(llm, llm_fs, adapter_map):
|
||||
src_contents[i] = serialization_formatter.parse_line(it)
|
||||
elif trust_remote_code_formatter.identifier in it:
|
||||
src_contents[i] = trust_remote_code_formatter.parse_line(it)
|
||||
elif OPENLLM_MODEL_PROMPT_TEMPLATE in it:
|
||||
if llm._prompt_template:
|
||||
src_contents[i] = ModelPromptTemplateFormatter(f'"""{llm._prompt_template.to_string()}"""').parse_line(it)
|
||||
else:
|
||||
src_contents[i] = ModelPromptTemplateFormatter(str(None)).parse_line(it)
|
||||
elif OPENLLM_MODEL_SYSTEM_MESSAGE in it:
|
||||
if llm._system_message:
|
||||
src_contents[i] = ModelSystemMessageFormatter(f'"""{llm._system_message}"""').parse_line(it)
|
||||
else:
|
||||
src_contents[i] = ModelSystemMessageFormatter(str(None)).parse_line(it)
|
||||
|
||||
script = f"# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if SHOW_CODEGEN:
|
||||
|
||||
@@ -105,7 +105,7 @@ async def cohere_generate(req, llm):
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
request_id = gen_random_uuid('cohere-generate')
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
if request.prompt_vars is not None:
|
||||
prompt = request.prompt.format(**request.prompt_vars)
|
||||
@@ -202,7 +202,7 @@ async def cohere_chat(req, llm):
|
||||
_transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
|
||||
@@ -156,7 +156,7 @@ async def chat_completions(req, llm):
|
||||
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
@@ -272,14 +272,9 @@ async def completions(req, llm):
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
|
||||
return error_response(
|
||||
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
|
||||
)
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.monotonic())
|
||||
config = llm.config.with_request(request)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
|
||||
@@ -135,13 +135,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
_OpenLLM_GenericInternalConfig().to_click_options,
|
||||
_OpenLLM_GenericInternalConfig.parse,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
dtype_option(factory=cog.optgroup),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
@@ -318,27 +316,6 @@ def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
)(f)
|
||||
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--backend',
|
||||
|
||||
@@ -37,8 +37,6 @@ def _start(
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
@@ -60,8 +58,6 @@ def _start(
|
||||
Args:
|
||||
model_id: The model id to start this LLMServer
|
||||
timeout: The server timeout
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -88,10 +84,6 @@ def _start(
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
args: list[str] = [model_id]
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
@@ -129,8 +121,6 @@ def _build(
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
@@ -155,8 +145,6 @@ def _build(
|
||||
model_id: The model id to build this BentoLLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento veresion for this given BentoLLM
|
||||
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
|
||||
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used..
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
@@ -209,10 +197,6 @@ def _build(
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
|
||||
@@ -100,11 +100,9 @@ from ._factory import (
|
||||
model_name_argument,
|
||||
model_version_option,
|
||||
parse_config_options,
|
||||
prompt_template_file_option,
|
||||
quantize_option,
|
||||
serialisation_option,
|
||||
start_decorator,
|
||||
system_message_option,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -404,8 +402,6 @@ def start_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -437,7 +433,6 @@ def start_command(
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
@@ -467,8 +462,6 @@ def start_command(
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
@@ -495,8 +488,6 @@ def start_command(
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
server = bentoml.HTTPServer('_service:svc', **server_attrs)
|
||||
@@ -541,8 +532,6 @@ def start_grpc_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
model_version: str | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
@@ -577,7 +566,6 @@ def start_grpc_command(
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
@@ -604,8 +592,6 @@ def start_grpc_command(
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
@@ -634,8 +620,6 @@ def start_grpc_command(
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs)
|
||||
@@ -654,18 +638,7 @@ def start_grpc_command(
|
||||
|
||||
|
||||
def process_environ(
|
||||
config,
|
||||
server_timeout,
|
||||
wpr,
|
||||
device,
|
||||
cors,
|
||||
model_id,
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
system_message,
|
||||
prompt_template,
|
||||
use_current_env=True,
|
||||
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True
|
||||
) -> t.Dict[str, t.Any]:
|
||||
environ = parse_config_options(
|
||||
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
|
||||
@@ -685,10 +658,6 @@ def process_environ(
|
||||
)
|
||||
if llm.quantise:
|
||||
environ['QUANTIZE'] = str(llm.quantise)
|
||||
if system_message:
|
||||
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template:
|
||||
environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
return environ
|
||||
|
||||
|
||||
@@ -929,8 +898,6 @@ class BuildBentoOutput(t.TypedDict):
|
||||
)
|
||||
@dtype_option
|
||||
@backend_option
|
||||
@system_message_option
|
||||
@prompt_template_file_option
|
||||
@click.option(
|
||||
'--bento-version',
|
||||
type=str,
|
||||
@@ -1004,8 +971,6 @@ def build_command(
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
backend: LiteralBackend | None,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
model_version: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
containerize: bool,
|
||||
@@ -1051,12 +1016,9 @@ def build_command(
|
||||
|
||||
state = ItemState.NOT_FOUND
|
||||
|
||||
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
@@ -1075,19 +1037,7 @@ def build_command(
|
||||
llm._tag = model.tag
|
||||
|
||||
os.environ.update(
|
||||
**process_environ(
|
||||
llm.config,
|
||||
llm.config['timeout'],
|
||||
1.0,
|
||||
None,
|
||||
True,
|
||||
llm.model_id,
|
||||
None,
|
||||
llm._serialisation,
|
||||
llm,
|
||||
llm._system_message,
|
||||
llm._prompt_template,
|
||||
)
|
||||
**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,33 +1,82 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import traceback
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
import openllm_core
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar,
|
||||
)
|
||||
@click.argument('model_id', shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@click.option('--prompt-template-file', type=click.File(), default=None)
|
||||
@click.option('--chat-template-file', type=click.File(), default=None)
|
||||
@click.option('--system-message', type=str, default=None)
|
||||
@click.option(
|
||||
'--add-generation-prompt/--no-add-generation-prompt',
|
||||
default=False,
|
||||
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
|
||||
)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
@@ -35,45 +84,96 @@ logger = logging.getLogger(__name__)
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(
|
||||
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
chat_template_file: t.IO[t.Any] | None,
|
||||
system_message: str | None,
|
||||
add_generation_prompt: bool,
|
||||
_memoized: dict[str, t.Any],
|
||||
**_: t.Any,
|
||||
) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = getattr(openllm_core.config, f'configuration_{model_name}')
|
||||
"""Helpers for generating prompts.
|
||||
|
||||
\b
|
||||
It accepts remote HF model_ids as well as model name passed to `openllm start`.
|
||||
|
||||
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
|
||||
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
|
||||
```
|
||||
|
||||
If you need change the prompt template, you can create the template file that contains the jina2 template through `--chat-template-file`
|
||||
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
|
||||
```
|
||||
|
||||
\b
|
||||
|
||||
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
|
||||
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt mistral "Hello there"
|
||||
"""
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
|
||||
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
|
||||
if template is None:
|
||||
raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
|
||||
)
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
|
||||
) from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
|
||||
)
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
|
||||
prompt, prompt_template=_prompt_template
|
||||
)[0]
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err
|
||||
ctx.fail(str(err))
|
||||
else:
|
||||
import transformers
|
||||
|
||||
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
if chat_template_file is not None:
|
||||
chat_template_file = chat_template_file.read()
|
||||
if system_message is None:
|
||||
logger.warning('system-message is not provided, using default infer from the model architecture.\n')
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_click_conversion(gen_settings: ModelSettings):
|
||||
return attrs
|
||||
|
||||
cl_ = make_llm_config('ClickConversionLLM', gen_settings)
|
||||
wrapped = cl_.to_click_options(cli_mock)
|
||||
wrapped = cl_.parse(cli_mock)
|
||||
filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')]
|
||||
assert len(filtered) == len(click_options_filtered)
|
||||
|
||||
Reference in New Issue
Block a user