diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index d6bbe32c..2f73c6f9 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -39,6 +39,7 @@ from ..exceptions import OpenLLMException from ..utils import DEBUG from ..utils import EnvVarMixin from ..utils import codegen +from ..utils import gpu_count from ..utils import is_flax_available from ..utils import is_tf_available from ..utils import is_torch_available @@ -236,7 +237,7 @@ def create_bento( bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], - workers_per_resource: int | float, + workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, @@ -249,12 +250,29 @@ def create_bento( _model_store: ModelStore = Provide[BentoMLContainer.model_store], ) -> bentoml.Bento: framework_envvar = llm.config["env"]["framework_value"] - labels = dict(llm.identifying_params) + labels = {"model_ids": llm.identifying_params["model_ids"]} labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"]}) if adapter_map: labels.update(adapter_map) + if isinstance(workers_per_resource, str): + if workers_per_resource == "round_robin": + workers_per_resource = 1.0 + elif workers_per_resource == "conserved": + available_gpu = gpu_count() + if len(available_gpu) != 0: + workers_per_resource = float(1 / len(available_gpu)) + else: + workers_per_resource = 1.0 + else: + try: + workers_per_resource = float(workers_per_resource) + except ValueError: + raise ValueError( + "'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies." + ) from None + logger.info("Building Bento for '%s'", llm.config["start_name"]) if adapter_map is not None: diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 5fb88558..9dedd64c 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -2472,6 +2472,11 @@ def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputL "format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})", ) + if format not in module.PROMPT_MAPPING: + raise click.BadOptionUsage( + "format", + f"Given format {format} is not valid for {model_name} (available format: {list(module.PROMPT_MAPPING)})", + ) _prompt = template(format) else: _prompt = template @@ -2479,7 +2484,7 @@ def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputL fully_formatted = _prompt.format(instruction=prompt) if output == "porcelain": - _echo(f'__prompt__:"{fully_formatted}"', fg="white") + _echo(repr(fully_formatted), fg="white") elif output == "json": _echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white") else: