mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-11 03:33:44 -04:00
feat(client): embeddings (#146)
This commit is contained in:
37
README.md
37
README.md
@@ -474,6 +474,43 @@ LLMs into the ecosystem. Check out
|
||||
[Adding a New Model Guide](https://github.com/bentoml/OpenLLM/blob/main/ADDING_NEW_MODEL.md)
|
||||
to see how you can do it yourself.
|
||||
|
||||
### Embeddings
|
||||
|
||||
OpenLLM tentatively provides an embeddings endpoint for supported models.
|
||||
This can be accessed via `/v1/embeddings`.
|
||||
|
||||
To use via CLI, simply call ``openllm embed``:
|
||||
|
||||
```bash
|
||||
openllm embed --endpoint http://localhost:3000 "I like to eat apples" -o json
|
||||
{
|
||||
"embeddings": [
|
||||
0.006569798570126295,
|
||||
-0.031249752268195152,
|
||||
-0.008072729222476482,
|
||||
0.00847396720200777,
|
||||
-0.005293501541018486,
|
||||
...<many embeddings>...
|
||||
-0.002078012563288212,
|
||||
-0.00676426338031888,
|
||||
-0.002022686880081892
|
||||
],
|
||||
"num_tokens": 9
|
||||
}
|
||||
```
|
||||
|
||||
To invoke this endpoint, use ``client.embed`` from the Python SDK:
|
||||
|
||||
```python
|
||||
import openllm
|
||||
|
||||
client = openllm.client.HTTPClient("http://localhost:3000")
|
||||
|
||||
client.embed("I like to eat apples")
|
||||
```
|
||||
|
||||
> **Note**: Currently, only LLaMA models and variants support embeddings.
|
||||
|
||||
## ⚙️ Integrations
|
||||
|
||||
OpenLLM is not just a standalone product; it's a building block designed to
|
||||
|
||||
18
changelog.d/146.feature.md
Normal file
18
changelog.d/146.feature.md
Normal file
@@ -0,0 +1,18 @@
|
||||
Users can now call ``client.embed`` to get the embeddings from the running LLMServer
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:3000")
|
||||
|
||||
client.embed("Hello World")
|
||||
client.embed(["Hello", "World"])
|
||||
```
|
||||
|
||||
> **Note:** ``client.embed`` is currently only implemented for ``openllm.client.HTTPClient`` and ``openllm.client.AsyncHTTPClient``
|
||||
|
||||
Users can also query embeddings directly from the CLI, via ``openllm embed``:
|
||||
|
||||
```bash
|
||||
$ openllm embed --endpoint localhost:3000 "Hello World" "My name is Susan"
|
||||
|
||||
[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
```
|
||||
@@ -48,7 +48,7 @@ if runner.supports_embeddings:
|
||||
@svc.api(input=bentoml.io.JSON.from_sample(sample=["Hey Jude, welcome to the jumgle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample(sample={"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings")
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
responses = await runner.embeddings.async_run(phrases)
|
||||
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"].tolist()[0], num_tokens=responses["num_tokens"])
|
||||
return openllm.EmbeddingsOutput(embeddings=t.cast(t.List[t.List[float]], responses["embeddings"].tolist())[0], num_tokens=responses["num_tokens"])
|
||||
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
|
||||
@@ -30,6 +30,7 @@ from bentoml._internal.resource import get_resource
|
||||
from bentoml._internal.resource import system_resources
|
||||
from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
|
||||
from .utils import DEBUG
|
||||
from .utils import LazyType
|
||||
from .utils import ReprMixin
|
||||
|
||||
@@ -115,7 +116,7 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if not psutil.LINUX:
|
||||
warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
return []
|
||||
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
|
||||
@@ -43,6 +43,7 @@ import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -58,6 +59,7 @@ import fs.copy
|
||||
import fs.errors
|
||||
import inflection
|
||||
import orjson
|
||||
import psutil
|
||||
import yaml
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from bentoml_cli.utils import opt_callback
|
||||
@@ -77,6 +79,7 @@ from .utils import LazyLoader
|
||||
from .utils import LazyType
|
||||
from .utils import analytics
|
||||
from .utils import available_devices
|
||||
from .utils import bentoml_cattr
|
||||
from .utils import codegen
|
||||
from .utils import configure_logging
|
||||
from .utils import dantic
|
||||
@@ -1571,6 +1574,38 @@ def instruct(endpoint: str, timeout: int, agent: t.LiteralString, output: Output
|
||||
return result
|
||||
else: raise click.BadOptionUsage("agent", f"Unknown agent type {agent}")
|
||||
|
||||
@overload
|
||||
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[True] = True) -> openllm.EmbeddingsOutput: ...
|
||||
@overload
|
||||
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[False] = False) -> None: ...
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@click.option("--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True)
|
||||
@click.argument("text", type=click.STRING, nargs=-1)
|
||||
@machine_option(click)
|
||||
@click.pass_context
|
||||
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: OutputLiteral, machine: bool) -> openllm.EmbeddingsOutput | None:
|
||||
"""Get embeddings interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
|
||||
```
|
||||
"""
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
try:
|
||||
gen_embed = client.embed(list(text))
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None
|
||||
if machine: return gen_embed
|
||||
elif output == "pretty":
|
||||
_echo("Generated embeddings:", fg="magenta")
|
||||
_echo(gen_embed.embeddings, fg="white")
|
||||
_echo("\nNumber of tokens:", fg="magenta")
|
||||
_echo(gen_embed.num_tokens, fg="white")
|
||||
elif output == "json": _echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
else: _echo(gen_embed.embeddings, fg="white")
|
||||
ctx.exit(0)
|
||||
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@@ -1616,17 +1651,25 @@ def list_bentos(ctx: click.Context):
|
||||
_echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
ctx.exit(0)
|
||||
|
||||
@overload
|
||||
def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[True] = True, _bento_store: BentoStore = ...) -> str: ...
|
||||
@overload
|
||||
def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[False] = False, _bento_store: BentoStore = ...) -> None: ...
|
||||
@utils_command.command()
|
||||
@click.argument("bento", type=str)
|
||||
@machine_option(click)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def dive_bentos(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None:
|
||||
def dive_bentos(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try: bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build first`")
|
||||
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
if machine: return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
_echo(f"cd $(python -m bentoml get {bentomodel.tag!s} -o path)", fg="white")
|
||||
if psutil.WINDOWS: subprocess.check_output([shutil.which("dir") or "dir"], cwd=bentomodel.path)
|
||||
else:subprocess.check_output([shutil.which("ls") or "ls", "-R"], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
|
||||
@overload
|
||||
def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputLiteral, machine: t.Literal[True] = True) -> str: ...
|
||||
|
||||
@@ -37,41 +37,27 @@ else:
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
from openllm._types import LiteralRuntime
|
||||
|
||||
class AnnotatedClient(bentoml.client.Client):
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
async def async_health(self) -> t.Any:
|
||||
...
|
||||
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
def metadata_v1(self) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
else:
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_health(self) -> t.Any: ...
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ...
|
||||
def metadata_v1(self) -> dict[str, t.Any]: ...
|
||||
def embeddings_v1(self) -> t.Sequence[float]: ...
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def in_async_context() -> bool:
|
||||
try:
|
||||
_ = asyncio.get_running_loop()
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
except RuntimeError: return False
|
||||
|
||||
T = t.TypeVar("T")
|
||||
|
||||
|
||||
class ClientMeta(t.Generic[T]):
|
||||
_api_version: str
|
||||
_client_class: type[bentoml.client.Client]
|
||||
|
||||
_host: str
|
||||
_port: str
|
||||
|
||||
@@ -82,84 +68,50 @@ class ClientMeta(t.Generic[T]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._address = address
|
||||
self._timeout = timeout
|
||||
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
|
||||
"""Initialise subclass for HTTP and gRPC client type."""
|
||||
cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
|
||||
cls._api_version = api_version
|
||||
|
||||
@property
|
||||
def _hf_agent(self) -> transformers.HfAgent:
|
||||
if not self.supports_hf_agent:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"{self.model_name} ({self.framework}) does not support running HF agent."
|
||||
)
|
||||
if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
|
||||
if self.__agent__ is None:
|
||||
if not openllm.utils.is_transformers_supports_agent():
|
||||
raise RuntimeError(
|
||||
"Current 'transformers' does not support Agent."
|
||||
" Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'"
|
||||
)
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
|
||||
self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
|
||||
return self.__agent__
|
||||
|
||||
@property
|
||||
def _metadata(self) -> T:
|
||||
if in_async_context():
|
||||
return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
|
||||
if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
|
||||
return self.call("metadata")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def model_name(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def framework(self) -> LiteralRuntime:
|
||||
raise NotImplementedError
|
||||
|
||||
def framework(self) -> LiteralRuntime: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def timeout(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def timeout(self) -> int: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def model_id(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_embeddings(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_embeddings(self) -> bool: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_hf_agent(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_hf_agent(self) -> bool: raise NotImplementedError
|
||||
@property
|
||||
def llm(self) -> openllm.LLM[t.Any, t.Any]:
|
||||
if self.__llm__ is None:
|
||||
self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
|
||||
if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
|
||||
return self.__llm__
|
||||
|
||||
@property
|
||||
def config(self) -> openllm.LLMConfig:
|
||||
return self.llm.config
|
||||
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
|
||||
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
|
||||
def config(self) -> openllm.LLMConfig: return self.llm.config
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
@property
|
||||
def _cached(self) -> AnnotatedClient:
|
||||
if self.__client__ is None:
|
||||
@@ -170,73 +122,42 @@ class ClientMeta(t.Generic[T]):
|
||||
def prepare(self, prompt: str, **attrs: t.Any):
|
||||
return_raw_response = attrs.pop("return_raw_response", False)
|
||||
return return_raw_response, *self.llm.sanitize_parameters(prompt, **attrs)
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, result: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
|
||||
@abstractmethod
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
|
||||
|
||||
class BaseClient(ClientMeta[T]):
|
||||
def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
def health(self) -> t.Any: raise NotImplementedError
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
|
||||
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str:
|
||||
# NOTE: We set use_default_prompt_template to False for now.
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
return_attrs = attrs.pop("return_attrs", False)
|
||||
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
|
||||
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
|
||||
)
|
||||
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
|
||||
if in_async_context():
|
||||
result = httpx.post(
|
||||
urljoin(self._address, f"/{self._api_version}/generate"),
|
||||
json=inputs.model_dump(),
|
||||
timeout=self.timeout,
|
||||
).json()
|
||||
else:
|
||||
result = self.call("generate", inputs.model_dump())
|
||||
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json()
|
||||
else: result = self.call("generate", inputs.model_dump())
|
||||
r = self.postprocess(result)
|
||||
|
||||
if return_attrs:
|
||||
return r
|
||||
if return_raw_response:
|
||||
return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
if return_attrs: return r
|
||||
if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return self.query(prompt, **attrs)
|
||||
|
||||
def ask_agent(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
return_code: bool = False,
|
||||
remote: bool = False,
|
||||
agent_type: t.LiteralString = "hf",
|
||||
**attrs: t.Any,
|
||||
) -> t.Any:
|
||||
if agent_type == "hf":
|
||||
return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if len(args) > 1:
|
||||
raise ValueError("'args' should only take one positional argument.")
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
@@ -245,40 +166,20 @@ class BaseClient(ClientMeta[T]):
|
||||
except Exception as err:
|
||||
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
|
||||
logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
|
||||
|
||||
# NOTE: Scikit interface
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
return self.query(prompt, **attrs)
|
||||
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseAsyncClient(ClientMeta[T]):
|
||||
async def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
async def health(self) -> t.Any: raise NotImplementedError
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
async def embed(self, prompt: list[str] | str) -> t.Sequence[float]: raise NotImplementedError
|
||||
|
||||
@overload
|
||||
async def query(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
return_attrs: t.Literal[True] = True,
|
||||
return_raw_response: bool | None = ...,
|
||||
**attrs: t.Any,
|
||||
) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, return_raw_response: bool | None = ..., **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
async def query(
|
||||
self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any
|
||||
) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ...
|
||||
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput:
|
||||
# NOTE: We set use_default_prompt_template to False for now.
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
@@ -290,34 +191,18 @@ class BaseAsyncClient(ClientMeta[T]):
|
||||
res = await self.acall("generate", inputs.model_dump())
|
||||
r = self.postprocess(res)
|
||||
|
||||
if return_attrs:
|
||||
return r
|
||||
if return_raw_response:
|
||||
return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
if return_attrs: return r
|
||||
if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
async def ask_agent(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
return_code: bool = False,
|
||||
remote: bool = False,
|
||||
agent_type: t.LiteralString = "hf",
|
||||
**attrs: t.Any,
|
||||
) -> t.Any:
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
"""Async version of agent.run."""
|
||||
if agent_type == "hf":
|
||||
return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not openllm.utils.is_transformers_supports_agent():
|
||||
raise RuntimeError(
|
||||
"This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0"
|
||||
)
|
||||
if len(args) > 1:
|
||||
raise ValueError("'args' should only take one positional argument.")
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
@@ -366,8 +251,4 @@ class BaseAsyncClient(ClientMeta[T]):
|
||||
return f"{tool_code}\n{code}"
|
||||
|
||||
# NOTE: Scikit interface
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
return await self.query(prompt, **attrs)
|
||||
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return await self.query(prompt, **attrs)
|
||||
|
||||
@@ -15,14 +15,17 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
from .base import BaseAsyncClient
|
||||
from .base import BaseClient
|
||||
from .base import in_async_context
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -31,66 +34,41 @@ if t.TYPE_CHECKING:
|
||||
else:
|
||||
DictStrAny = dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPClientMixin:
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def _metadata(self) -> DictStrAny:
|
||||
...
|
||||
def _metadata(self) -> DictStrAny: ...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["framework"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return self._metadata["timeout"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata["timeout"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
try: return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return openllm.GenerationOutput(**result)
|
||||
|
||||
@@ -100,16 +78,24 @@ class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> t.Any:
|
||||
return self._cached.health()
|
||||
|
||||
def health(self) -> t.Any: return self._cached.health()
|
||||
def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if not self.supports_embeddings:
|
||||
raise ValueError("This model does not support embeddings.")
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=orjson.dumps(prompt).decode(), timeout=self.timeout)
|
||||
else: result = self.call("embeddings", orjson.dumps(prompt).decode())
|
||||
return openllm.EmbeddingsOutput(**result)
|
||||
|
||||
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
async def health(self) -> t.Any:
|
||||
return await self._cached.async_health()
|
||||
async def health(self) -> t.Any: return await self._cached.async_health()
|
||||
async def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if not self.supports_embeddings:
|
||||
raise ValueError("This model does not support embeddings.")
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
res = await self.acall("embeddings", orjson.dumps(prompt).decode())
|
||||
return openllm.EmbeddingsOutput(**res)
|
||||
|
||||
Reference in New Issue
Block a user