From b31cd0460bd6656ac1956c2640e710f36384ff0d Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Thu, 20 Jul 2023 21:40:56 +0000 Subject: [PATCH] fix: correct tag inference for model-id. In the case of build, the model_id is passed as a full valid tag under the bento store. XXX: We will need to fix this later. Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- src/openllm/_llm.py | 21 +++++++++++++-------- src/openllm/bundle/_package.py | 2 +- src/openllm/cli.py | 27 +++++++++++++++------------ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 4b1399fa..a4faca02 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -736,16 +736,21 @@ class LLM(LLMInterface[M, T], ReprMixin): @classmethod def _infer_tag_from_model_id(cls, model_id: str, model_version: str | None) -> bentoml.Tag: + # XXX: Fix me later, if the model is a valid tag, then we return it directly + # instead of creating a new tag from the model_id. 
this branch will be hit during `openllm build` try: - return bentoml.Tag.from_taglike(model_id) + return bentoml.models.get(model_id).tag except (ValueError, bentoml.exceptions.BentoMLException): - return make_tag( - model_id, - model_version=model_version, - trust_remote_code=cls.config_class.__openllm_trust_remote_code__, - implementation=cls.__llm_implementation__, - quiet=True, - ) + try: + return bentoml.Tag.from_taglike(model_id) + except (ValueError, bentoml.exceptions.BentoMLException): + return make_tag( + model_id, + model_version=model_version, + trust_remote_code=cls.config_class.__openllm_trust_remote_code__, + implementation=cls.__llm_implementation__, + quiet=True, + ) def __init__( self, diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index ca1eae18..d6bbe32c 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -175,7 +175,7 @@ def construct_python_options( if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}")) - return PythonOptions(packages=packages, wheels=wheels, lock_packages=True) + return PythonOptions(packages=packages, wheels=wheels, lock_packages=False) def construct_docker_options( diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 09e09885..5c196e95 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -201,6 +201,15 @@ def model_id_option(factory: t.Any, model_env: EnvVarMixin | None = None) -> t.C ) +def model_version_option(factory: t.Any) -> t.Callable[[FC], FC]: + return factory.option( + "--model-version", + type=click.STRING, + default=None, + help="Optional model version to save for this model. It will be inferred automatically from model-id.", + ) + + def workers_per_resource_option(factory: t.Any, build: bool = False) -> t.Callable[[FC], FC]: help_str = """Number of workers per resource assigned. 
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy @@ -600,6 +609,7 @@ def start_decorator( ), workers_per_resource_option(cog.optgroup), model_id_option(cog.optgroup, model_env=llm_config["env"]), + model_version_option(cog.optgroup), cog.optgroup.option( "--fast", is_flag=True, @@ -922,6 +932,7 @@ def start_bento( ctx: click.Context, server_timeout: int, model_id: str | None, + model_version: str | None, workers_per_resource: t.LiteralString | float, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, @@ -1070,6 +1081,7 @@ def start_model( ctx: click.Context, server_timeout: int, model_id: str | None, + model_version: str | None, workers_per_resource: t.LiteralString | float, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, @@ -1157,6 +1169,7 @@ def start_model( llm = openllm.infer_auto_class(env.framework_value).for_model( model_name, model_id=model_id, + model_version=model_version, llm_config=config, ensure_available=not fast, return_runner_kwargs=False, @@ -1254,12 +1267,7 @@ def start_model( required=False, ) @click.argument("converter", envvar="CONVERTER", type=click.STRING, default=None, required=False, metavar=None) -@click.option( - "--model-version", - type=click.STRING, - default=None, - help="Optional model version to save for this model. It will be inferred automatically from model-id.", -) +@model_version_option(click) @click.option( "--runtime", type=click.Choice(["ggml", "transformers"]), @@ -1803,12 +1811,7 @@ start, start_grpc, build, import_model, list_models = ( metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", ) @click.option("--build-ctx", default=".", help="Build context. 
This is required if --adapter-id uses relative path") -@click.option( - "--model-version", - default=None, - type=click.STRING, - help="Model version provided for this 'model-id' if it is a custom path.", -) +@model_version_option(click) @click.option( "--dockerfile-template", default=None,