refactor: packages (#249)

Aaron Pham authored on 2023-08-22 08:55:46 -04:00, committed by GitHub
parent a964e659c1
commit 3ffb25a872
148 changed files with 2899 additions and 1937 deletions
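In broad strokes, this refactor splits shared pieces of the `openllm` package out into `openllm_core`: typing shims move from `openllm._typing_compat` to `openllm_core._typing_compat`, schemas from `openllm._schema` to `openllm_core._schema`, and most utilities from `openllm.utils` to `openllm_core.utils`, while serving-specific helpers such as `infer_auto_class` stay under `openllm.utils`. A minimal before/after import sketch, using only module paths that appear in the diff below (the per-line symbol groupings are illustrative):

```python
# Before this change (paths taken from the removed lines):
# from openllm.utils import DEBUG, first_not_none, get_debug_mode, infer_auto_class
# from openllm._typing_compat import DictStrAny, ParamSpec
# from openllm._schema import EmbeddingsOutput
# from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy

# After this change (paths taken from the added lines):
from openllm_core.utils import DEBUG, first_not_none, get_debug_mode
from openllm_core._typing_compat import DictStrAny, ParamSpec, LiteralContainerRegistry, LiteralContainerVersionStrategy
from openllm_core._schema import EmbeddingsOutput
from openllm.utils import infer_auto_class  # serving-side helper, stays in openllm
```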


@@ -20,10 +20,9 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
```
"""
from __future__ import annotations
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t
import attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t, attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
from simple_di import Provide, inject
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
from . import termui
@@ -56,8 +55,8 @@ from openllm.models.auto import (
  AutoConfig,
  AutoLLM,
)
from openllm._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
from openllm.utils import (
from openllm_core._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
from openllm_core.utils import (
  DEBUG,
  DEBUG_ENV_VAR,
  OPTIONAL_DEPENDENCIES
@@ -72,21 +71,20 @@ from openllm.utils import (
  first_not_none,
  get_debug_mode,
  get_quiet_mode,
  infer_auto_class,
  is_torch_available,
  is_transformers_supports_agent,
  resolve_user_filepath,
  set_debug_mode,
  set_quiet_mode,
)
from openllm.utils import infer_auto_class
if t.TYPE_CHECKING:
  import torch
  from bentoml._internal.bento import BentoStore
  from bentoml._internal.container import DefaultBuilder
  from openllm.client import BaseClient
  from openllm._schema import EmbeddingsOutput
  from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
  from openllm_core._schema import EmbeddingsOutput
  from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
else: torch = LazyLoader("torch", globals(), "torch")
P = ParamSpec("P")
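The `else: torch = LazyLoader("torch", globals(), "torch")` branch above keeps `torch` usable as a module-level name without paying the import cost at CLI startup. A rough sketch of that lazy-import pattern, assuming a module proxy in the spirit of the TensorFlow/BentoML `LazyLoader` (not the actual implementation):

```python
import importlib
import types

class LazyLoader(types.ModuleType):
  """Proxy module that defers the real import until an attribute is first accessed."""
  def __init__(self, local_name: str, parent_globals: dict, name: str) -> None:
    super().__init__(name)
    self._local_name = local_name
    self._parent_globals = parent_globals

  def _load(self) -> types.ModuleType:
    module = importlib.import_module(self.__name__)
    self._parent_globals[self._local_name] = module  # swap the proxy for the real module
    self.__dict__.update(module.__dict__)
    return module

  def __getattr__(self, item: str):
    return getattr(self._load(), item)

torch = LazyLoader("torch", globals(), "torch")  # torch is imported only when first touched
```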
@@ -271,7 +269,7 @@ def cli() -> None:
  \b
  An open platform for operating large language models in production.
  Fine-tune, serve, deploy, and monitor any LLMs with ease.
  """ # noqa: D205
  """
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start", aliases=["start-http"])
def start_command() -> None:
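The `aliases=["start-http"]` argument above depends on the command group resolving alternate names to the same command; that behavior lives in `OpenLLMCommandGroup` (built on `BentoMLCommandGroup`). A generic sketch of how a Click group can support aliases, with names and wiring that are illustrative rather than openllm's actual implementation:

```python
import click

class AliasedGroup(click.Group):
  """A click.Group whose subcommands can also be looked up under alternate names."""
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._alias_map = {}  # alias -> canonical command name

  def add_command(self, cmd, name=None, aliases=()):
    super().add_command(cmd, name)
    for alias in aliases:
      self._alias_map[alias] = cmd.name

  def get_command(self, ctx, cmd_name):
    # Resolve aliases before falling back to the normal lookup.
    return super().get_command(ctx, self._alias_map.get(cmd_name, cmd_name))
```

Registering the `start` group with `aliases=("start-http",)` then lets both `openllm start` and `openllm start-http` reach the same command.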
@@ -670,10 +668,8 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
"""
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
try:
client.call("metadata")
except http.client.BadStatusLine:
raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None
try: client.call("metadata")
except http.client.BadStatusLine: raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None
if agent == "hf":
if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'")
_memoized = {k: v[0] for k, v in _memoized.items() if v}
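The one-line `try: client.call("metadata")` probe above only exists to turn an unreachable or non-HTTP endpoint into a readable CLI error before any real work begins. A standalone sketch of the same idea using only the standard library; the `/metadata` path and the `ensure_reachable` name are hypothetical:

```python
import http.client

import click

def ensure_reachable(endpoint: str, timeout: int = 30) -> None:
  """Fail fast with a CLI-friendly error if the server does not answer HTTP."""
  host = endpoint.removeprefix("http://")
  conn = http.client.HTTPConnection(host, timeout=timeout)
  try:
    conn.request("GET", "/metadata")  # hypothetical path; any cheap endpoint works
    conn.getresponse()
  except (http.client.HTTPException, OSError):
    raise click.ClickException(f"{endpoint} is neither an HTTP server nor reachable.") from None
  finally:
    conn.close()
```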
@@ -700,7 +696,7 @@ def embed_command(ctx: click.Context, text: tuple[str, ...], endpoint: str, time
  $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
  ```
  """
  client = t.cast("BaseClient[t.Any]", openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout))
  client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
  try:
    gen_embed = client.embed(text)
  except ValueError:
@@ -733,14 +729,14 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
"""
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
if server_type == "grpc": endpoint = re.sub(r"http://", "", endpoint)
client = t.cast("BaseClient[t.Any]", openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout))
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
input_fg, generated_fg = "magenta", "cyan"
if output != "porcelain":
termui.echo("==Input==\n", fg="white")
termui.echo(f"{prompt}", fg=input_fg)
res = client.query(prompt, return_response="raw", **{**client.configuration, **_memoized})
if output == "pretty":
response = client.llm.postprocess_generate(prompt, res["responses"])
response = client.config.postprocess_generate(prompt, res["responses"])
termui.echo("\n\n==Responses==\n", fg="white")
termui.echo(response, fg=generated_fg)
elif output == "json":