feat(cli): show runtime implementation

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-06-11 05:29:11 +00:00
parent 06c90c0ba3
commit 17241292da
7 changed files with 36 additions and 25 deletions

View File

@@ -179,6 +179,7 @@ unfixable = [
convention = "google"
[tool.ruff.isort]
force-single-line = true
known-first-party = ["openllm", "bentoml", 'transformers']
lines-after-imports = 2

View File

@@ -42,7 +42,13 @@ _import_structure = {
"client": [],
"cli": ["start", "start_grpc"],
# NOTE: models
"models.auto": ["AutoConfig", "CONFIG_MAPPING"],
"models.auto": [
"AutoConfig",
"CONFIG_MAPPING",
"MODEL_MAPPING_NAMES",
"MODEL_FLAX_MAPPING_NAMES",
"MODEL_TF_MAPPING_NAMES",
],
"models.flan_t5": ["FlanT5Config"],
"models.dolly_v2": ["DollyV2Config"],
"models.falcon": ["FalconConfig"],
@@ -89,7 +95,7 @@ else:
_import_structure["models.dolly_v2"].extend(["DollyV2"])
_import_structure["models.starcoder"].extend(["StarCoder"])
_import_structure["models.stablelm"].extend(["StableLM"])
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING_NAMES", "MODEL_MAPPING"])
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
try:
if not utils.is_flax_available():
@@ -102,7 +108,7 @@ except MissingDependencyError:
]
else:
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING_NAMES", "MODEL_FLAX_MAPPING"])
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
try:
if not utils.is_tf_available():
@@ -113,7 +119,7 @@ except MissingDependencyError:
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
else:
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING_NAMES", "MODEL_TF_MAPPING"])
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
# declaration for OpenLLM-related modules
@@ -133,6 +139,9 @@ if t.TYPE_CHECKING:
from .cli import start as start
from .cli import start_grpc as start_grpc
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .models.auto import AutoConfig as AutoConfig
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
from .models.dolly_v2 import DollyV2Config as DollyV2Config
@@ -166,7 +175,6 @@ if t.TYPE_CHECKING:
from .utils.dummy_pt_objects import *
else:
from .models.auto import MODEL_MAPPING as MODEL_MAPPING
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .models.auto import AutoLLM as AutoLLM
from .models.dolly_v2 import DollyV2 as DollyV2
from .models.flan_t5 import FlanT5 as FlanT5
@@ -180,7 +188,6 @@ if t.TYPE_CHECKING:
from .utils.dummy_flax_objects import *
else:
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .models.auto import AutoFlaxLLM as AutoFlaxLLM
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
@@ -191,7 +198,6 @@ if t.TYPE_CHECKING:
from .utils.dummy_tf_objects import *
else:
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
from .models.auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .models.auto import AutoTFLLM as AutoTFLLM
from .models.flan_t5 import TFFlanT5 as TFFlanT5

View File

@@ -699,14 +699,21 @@ def cli_factory() -> click.Group:
else:
failed_initialized: list[tuple[str, Exception]] = []
json_data: dict[str, dict[t.Literal["model_id", "description"], t.Any]] = {}
json_data: dict[str, dict[t.Literal["model_id", "description", "runtime_impl"], t.Any]] = {}
converted: list[str] = []
for m in models:
try:
model = openllm.AutoLLM.for_model(m)
docs = inspect.cleandoc(model.config.__doc__ or "(No description)")
json_data[m] = {"model_id": model.model_ids, "description": docs}
runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] = tuple()
if model.config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES:
runtime_impl += ("pt",)
if model.config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES:
runtime_impl += ("flax",)
if model.config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES:
runtime_impl += ("tf",)
json_data[m] = {"model_id": model.model_ids, "description": docs, "runtime_impl": runtime_impl}
converted.extend([convert_transformers_model_name(i) for i in model.model_ids])
except Exception as err:
failed_initialized.append((m, err))
@@ -720,10 +727,10 @@ def cli_factory() -> click.Group:
tabulate.PRESERVE_WHITESPACE = True
data: list[str | tuple[str, str, list[str]]] = []
data: list[str | tuple[str, str, list[str], tuple[t.Literal["pt", "flax", "tf"], ...]]] = []
for m, v in json_data.items():
data.extend([(m, v["description"], v["model_id"])])
column_widths = [int(COLUMNS / 6), int(COLUMNS / 3 * 2), int(COLUMNS / 6)]
data.extend([(m, v["description"], v["model_id"], v["runtime_impl"])])
column_widths = [int(COLUMNS / 6), int(COLUMNS / 2), int(COLUMNS / 3), int(COLUMNS / 9)]
if len(data) == 0 and len(failed_initialized) > 0:
_echo("Exception found while parsing models:\n", fg="yellow")
@@ -735,7 +742,7 @@ def cli_factory() -> click.Group:
table = tabulate.tabulate(
data,
tablefmt="fancy_grid",
headers=["LLM", "Description", "Models Id"],
headers=["LLM", "Description", "Models Id", "Runtime"],
maxcolwidths=column_widths,
)

View File

@@ -24,6 +24,9 @@ from ... import utils
_import_structure = {
"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"],
"modeling_auto": ["MODEL_MAPPING_NAMES"],
"modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"],
"modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"],
}
try:
@@ -32,7 +35,7 @@ try:
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_auto"] = ["AutoLLM", "MODEL_MAPPING_NAMES", "MODEL_MAPPING"]
_import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
try:
if not utils.is_flax_available():
@@ -40,7 +43,7 @@ try:
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_flax_auto"] = ["AutoFlaxLLM", "MODEL_FLAX_MAPPING_NAMES", "MODEL_FLAX_MAPPING"]
_import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
try:
if not utils.is_tf_available():
@@ -48,12 +51,15 @@ try:
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_tf_auto"] = ["AutoTFLLM", "MODEL_TF_MAPPING_NAMES", "MODEL_TF_MAPPING"]
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if t.TYPE_CHECKING:
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
from .configuration_auto import AutoConfig as AutoConfig
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
try:
if not utils.is_torch_available():
@@ -62,7 +68,6 @@ if t.TYPE_CHECKING:
pass
else:
from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
from .modeling_auto import AutoLLM as AutoLLM
try:
@@ -72,7 +77,6 @@ if t.TYPE_CHECKING:
pass
else:
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .modeling_flax_auto import AutoFlaxLLM as AutoFlaxLLM
try:
@@ -82,7 +86,6 @@ if t.TYPE_CHECKING:
pass
else:
from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .modeling_tf_auto import AutoTFLLM as AutoTFLLM
else:
import sys

View File

@@ -33,6 +33,4 @@ class AutoFlaxLLM(metaclass=DummyMetaclass):
require_backends(self, ["flax"])
MODEL_FLAX_MAPPING_NAMES = None
MODEL_FLAX_MAPPING = None

View File

@@ -54,6 +54,4 @@ class AutoLLM(metaclass=DummyMetaclass):
require_backends(self, ["torch"])
MODEL_MAPPING_NAMES = None
MODEL_MAPPING = None

View File

@@ -33,6 +33,4 @@ class AutoTFLLM(metaclass=DummyMetaclass):
require_backends(self, ["tf"])
MODEL_TF_MAPPING_NAMES = None
MODEL_TF_MAPPING = None