feat: parse revision from model_id (#126)

This commit is contained in:
Aaron Pham
2023-07-20 14:36:53 -04:00
committed by GitHub
parent a056365d48
commit 858c2007c3
7 changed files with 83 additions and 16 deletions

View File

@@ -183,7 +183,7 @@ def make_tag(
)
return bentoml.Tag.from_taglike(
f"{model_name if in_docker() and os.getenv('BENTO_PATH') is not None else implementation + '-' + model_name}:{model_version}".strip()
f"{model_name if in_docker() and os.getenv('BENTO_PATH') is not None else implementation + '-' + model_name}:{model_version}".strip().lower()
)
@@ -671,6 +671,16 @@ class LLM(LLMInterface[M, T], ReprMixin):
if runtime is None:
runtime = cfg_cls.__openllm_runtime__
model_id, *maybe_revision = model_id.rsplit(":")
if len(maybe_revision) > 0:
if model_version is not None:
logger.warning(
"revision is specified within 'model_id' (%s), which will override the 'model_version=%s'",
maybe_revision[0],
model_version,
)
model_version = maybe_revision[0]
# quantization setup
if quantization_config and quantize:
raise ValueError(
@@ -728,7 +738,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
def _infer_tag_from_model_id(cls, model_id: str, model_version: str | None) -> bentoml.Tag:
try:
return bentoml.Tag.from_taglike(model_id)
except ValueError:
except (ValueError, bentoml.exceptions.BentoMLException):
return make_tag(
model_id,
model_version=model_version,

View File

@@ -70,6 +70,7 @@ from bentoml._internal.models.model import ModelStore
from .__about__ import __version__
from .exceptions import OpenLLMException
from .utils import DEBUG
from .utils import ENV_VARS_TRUE_VALUES
from .utils import EnvVarMixin
from .utils import LazyLoader
from .utils import LazyType
@@ -173,7 +174,7 @@ def _echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.A
call(text, **attrs)
output_option: t.Callable[[FC], FC] = click.option(
output_option: t.Callable[[_AnyCallable], _AnyCallable] = click.option(
"-o",
"--output",
type=click.Choice(["json", "pretty", "porcelain"]),
@@ -1048,6 +1049,12 @@ def start_model(
return_process: bool,
**attrs: t.Any,
) -> openllm.LLMConfig | subprocess.Popen[bytes]:
if serialisation_format == "safetensors" and quantize is not None:
if os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES:
_echo(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=True\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
fg="yellow",
)
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
@@ -2373,20 +2380,21 @@ def query(
else openllm.client.GrpcClient(endpoint, timeout=timeout)
)
input_fg = "yellow"
input_fg = "magenta"
generated_fg = "cyan"
if output != "porcelain":
_echo("Input prompt: ", nl=False, fg="white")
_echo(f"{prompt}", fg="magenta", nl=False)
_echo(f"{prompt}", fg=input_fg, nl=False)
res = client.query(prompt, return_raw_response=True)
if output == "pretty":
formatted = client.llm.postprocess_generate(prompt, res["responses"])
full_formatted = client.llm.postprocess_generate(prompt, res["responses"])
response = full_formatted[len(prompt) + 1 :]
_echo("\n\n==Responses==\n", fg="white")
_echo(f"{prompt} ", fg=input_fg, nl=False)
_echo(formatted, fg=generated_fg)
_echo(response, fg=generated_fg)
elif output == "json":
_echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white")
else:
@@ -2395,6 +2403,46 @@ def query(
ctx.exit(0)
@cli.group()
def utils():
    """Utilities subcommand group: miscellaneous helper commands for OpenLLM."""
@utils.command()
@click.argument(
    "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@click.argument("prompt", type=click.STRING)
@output_option
@click.option("--format", type=click.STRING, default=None)
def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputLiteral):
    """Get the default prompt used by OpenLLM.

    Renders the model's DEFAULT_PROMPT_TEMPLATE with PROMPT as the
    'instruction' field and prints it in the style selected by --output
    (porcelain / json / pretty).

    MODEL_NAME is a dasherized model name from openllm.CONFIG_MAPPING.
    --format selects a template variant when the model exposes multiple
    (required iff the template is a callable; available keys come from the
    model module's PROMPT_MAPPING).

    Raises click.BadOptionUsage when --format is missing for a callable
    template, and click.ClickException when the model has no default
    prompt template.
    """
    try:
        # Lazily resolve the model's module (e.g. openllm.models.<name>);
        # attribute access below may raise AttributeError if the module
        # does not define DEFAULT_PROMPT_TEMPLATE — handled at the bottom.
        module = openllm.utils.EnvVarMixin(model_name).module
        template = module.DEFAULT_PROMPT_TEMPLATE
        if callable(template):
            # Callable templates (e.g. llama, mpt) dispatch on a variant
            # key, so the user must pick one via --format.
            if format is None:
                raise click.BadOptionUsage(
                    "format",
                    f"{model_name} prompt requires passing '--format' (available format: {module.PROMPT_MAPPING})",
                )
            _prompt = template(format)
        else:
            _prompt = template
        # Templates are str.format-style strings with an 'instruction' slot.
        fully_formatted = _prompt.format(instruction=prompt)
        if output == "porcelain":
            # Machine-readable single-line form.
            _echo(f'__prompt__:"{fully_formatted}"', fg="white")
        elif output == "json":
            _echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
        else:
            # "pretty" (default): human-readable banner plus the prompt.
            _echo(f"== Prompt for {model_name} ==\n", fg="magenta")
            _echo(fully_formatted, fg="white")
    except AttributeError:
        # Model module lacks DEFAULT_PROMPT_TEMPLATE; 'from None' hides the
        # internal AttributeError from the user-facing error.
        raise click.ClickException(f"{model_name} does not have default prompt template.") from None
def load_notebook_metadata() -> DictStrAny:
with open(os.path.join(os.path.dirname(openllm.playground.__file__), "_meta.yml"), "r") as f:
content = yaml.safe_load(f)

View File

@@ -22,7 +22,12 @@ from ...utils import is_vllm_available
_import_structure: dict[str, list[str]] = {
"configuration_llama": ["LlaMAConfig", "START_LLAMA_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"],
"configuration_llama": [
"LlaMAConfig",
"START_LLAMA_COMMAND_DOCSTRING",
"DEFAULT_PROMPT_TEMPLATE",
"PROMPT_MAPPING",
],
}
try:
@@ -44,6 +49,7 @@ else:
if t.TYPE_CHECKING:
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
from .configuration_llama import LlaMAConfig as LlaMAConfig

View File

@@ -126,14 +126,14 @@ _v2_prompt = """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instructi
# XXX: implement me
_v1_prompt = """{instruction}"""
_PROMPT_MAPPING = {
PROMPT_MAPPING = {
"v1": _v1_prompt,
"v2": _v2_prompt,
}
def _get_prompt(model_type: t.Literal["v1", "v2"]) -> str:
return _PROMPT_MAPPING[model_type]
return PROMPT_MAPPING[model_type]
DEFAULT_PROMPT_TEMPLATE = _get_prompt

View File

@@ -21,7 +21,7 @@ from ...utils import is_torch_available
_import_structure: dict[str, list[str]] = {
"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"],
"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"],
}
try:
@@ -35,6 +35,7 @@ else:
if t.TYPE_CHECKING:
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
from .configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
from .configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
from .configuration_mpt import MPTConfig as MPTConfig

View File

@@ -127,7 +127,7 @@ _default_prompt = """{instruction}"""
# TODO: XXX implement me
_chat_prompt = """{instruction}"""
_PROMPT_MAPPING = {
PROMPT_MAPPING = {
"default": _default_prompt,
"instruct": _instruct_prompt,
"storywriter": _default_prompt,
@@ -136,7 +136,7 @@ _PROMPT_MAPPING = {
def _get_prompt(model_type: str) -> str:
return _PROMPT_MAPPING[model_type]
return PROMPT_MAPPING[model_type]
DEFAULT_PROMPT_TEMPLATE = _get_prompt

View File

@@ -45,9 +45,11 @@ if t.TYPE_CHECKING:
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
from .._types import LiteralRuntime
from .._types import P
from .._types import T
class _AnnotatedLazyLoader(LazyLoader):
DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[..., t.LiteralString]
class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]):
DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString]
PROMPT_MAPPING: dict[T, t.LiteralString] | None
else:
_AnnotatedLazyLoader = LazyLoader
@@ -534,5 +536,5 @@ class EnvVarMixin(ReprMixin):
return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
@property
def module(self) -> _AnnotatedLazyLoader:
def module(self) -> _AnnotatedLazyLoader[t.LiteralString]:
return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")