feat(client): embeddings (#146)

This commit is contained in:
Aaron Pham
2023-07-25 05:44:21 -04:00
committed by GitHub
parent dcd34bd381
commit 1940086bec
7 changed files with 189 additions and 223 deletions

View File

@@ -474,6 +474,43 @@ LLMs into the ecosystem. Check out
[Adding a New Model Guide](https://github.com/bentoml/OpenLLM/blob/main/ADDING_NEW_MODEL.md)
to see how you can do it yourself.
### Embeddings
OpenLLM tentatively provides an embeddings endpoint for supported models.
This can be accessed via `/v1/embeddings`.
To use via CLI, simply call ``openllm embed``:
```bash
openllm embed --endpoint http://localhost:3000 "I like to eat apples" -o json
{
"embeddings": [
0.006569798570126295,
-0.031249752268195152,
-0.008072729222476482,
0.00847396720200777,
-0.005293501541018486,
...<many embeddings>...
-0.002078012563288212,
-0.00676426338031888,
-0.002022686880081892
],
"num_tokens": 9
}
```
To invoke this endpoint, use ``client.embed`` from the Python SDK:
```python
import openllm
client = openllm.client.HTTPClient("http://localhost:3000")
client.embed("I like to eat apples")
```
> **Note**: Currently, only LLaMA models and variants support embeddings.
## ⚙️ Integrations
OpenLLM is not just a standalone product; it's a building block designed to

View File

@@ -0,0 +1,18 @@
Users now can call ``client.embed`` to get the embeddings from the running LLMServer
```python
client = openllm.client.HTTPClient("http://localhost:3000")
client.embed("Hello World")
client.embed(["Hello", "World"])
```
> **Note:** The ``client.embed`` is currently only implemented for ``openllm.client.HTTPClient`` and ``openllm.client.AsyncHTTPClient``
Users can also query embeddings directly from the CLI, via ``openllm embed``:
```bash
$ openllm embed --endpoint localhost:3000 "Hello World" "My name is Susan"
[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```

View File

@@ -48,7 +48,7 @@ if runner.supports_embeddings:
@svc.api(input=bentoml.io.JSON.from_sample(sample=["Hey Jude, welcome to the jumgle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample(sample={"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings")
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
    """Serve embeddings for a batch of phrases at ``/v1/embeddings``.

    The diff rendering left both the pre- and post-commit ``return`` lines in
    place; only the post-commit version (with the explicit cast) is kept here.
    """
    responses = await runner.embeddings.async_run(phrases)
    # The runner returns a batched ndarray-like result; ``tolist()[0]`` takes
    # the first (and, for this endpoint, only) batch entry.
    return openllm.EmbeddingsOutput(embeddings=t.cast(t.List[t.List[float]], responses["embeddings"].tolist())[0], num_tokens=responses["num_tokens"])
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
@attr.define
class HfAgentInput:

View File

@@ -30,6 +30,7 @@ from bentoml._internal.resource import get_resource
from bentoml._internal.resource import system_resources
from bentoml._internal.runner.strategy import THREAD_ENVS
from .utils import DEBUG
from .utils import LazyType
from .utils import ReprMixin
@@ -115,7 +116,7 @@ def _from_system(cls: type[DynResource]) -> list[str]:
if visible_devices is None:
if cls.resource_id == "amd.com/gpu":
if not psutil.LINUX:
warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
return []
# ROCm does not currently have the rocm_smi wheel.

View File

@@ -43,6 +43,7 @@ import logging
import os
import pkgutil
import re
import shutil
import subprocess
import sys
import tempfile
@@ -58,6 +59,7 @@ import fs.copy
import fs.errors
import inflection
import orjson
import psutil
import yaml
from bentoml_cli.utils import BentoMLCommandGroup
from bentoml_cli.utils import opt_callback
@@ -77,6 +79,7 @@ from .utils import LazyLoader
from .utils import LazyType
from .utils import analytics
from .utils import available_devices
from .utils import bentoml_cattr
from .utils import codegen
from .utils import configure_logging
from .utils import dantic
@@ -1571,6 +1574,38 @@ def instruct(endpoint: str, timeout: int, agent: t.LiteralString, output: Output
return result
else: raise click.BadOptionUsage("agent", f"Unknown agent type {agent}")
@overload
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[True] = True) -> openllm.EmbeddingsOutput: ...
@overload
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[False] = False) -> None: ...
@cli.command()
@shared_client_options
@click.option("--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True)
@click.argument("text", type=click.STRING, nargs=-1)
@machine_option(click)
@click.pass_context
def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: OutputLiteral, machine: bool) -> openllm.EmbeddingsOutput | None:
    """Get embeddings interactively, from a terminal.
    \b
    ```bash
    $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
    ```
    """
    # Pick the client implementation matching the requested transport.
    if server_type == "http":
        llm_client = openllm.client.HTTPClient(endpoint, timeout=timeout)
    else:
        llm_client = openllm.client.GrpcClient(endpoint, timeout=timeout)
    try:
        result = llm_client.embed(list(text))
    except ValueError:
        # The client raises ValueError when the served model lacks embeddings.
        raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None
    # --machine callers consume the structured output programmatically.
    if machine:
        return result
    if output == "pretty":
        _echo("Generated embeddings:", fg="magenta")
        _echo(result.embeddings, fg="white")
        _echo("\nNumber of tokens:", fg="magenta")
        _echo(result.num_tokens, fg="white")
    elif output == "json":
        _echo(orjson.dumps(bentoml_cattr.unstructure(result), option=orjson.OPT_INDENT_2).decode(), fg="white")
    else:
        _echo(result.embeddings, fg="white")
    ctx.exit(0)
@cli.command()
@shared_client_options
@@ -1616,17 +1651,25 @@ def list_bentos(ctx: click.Context):
_echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
ctx.exit(0)
@overload
def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[True] = True, _bento_store: BentoStore = ...) -> str: ...
@overload
def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[False] = False, _bento_store: BentoStore = ...) -> None: ...
@utils_command.command()
@click.argument("bento", type=str)
@machine_option(click)
@click.pass_context
@inject
def dive_bentos(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
    """Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
    # NOTE: the diff rendering kept the stale pre-commit signature (without
    # ``machine``) next to the new one; only the post-commit signature remains.
    try:
        bentomodel = _bento_store.get(bento)
    except bentoml.exceptions.NotFound:
        ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build first`")
    if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle":
        ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
    # --machine callers get the raw path back instead of shell output.
    if machine:
        return bentomodel.path
    # copy and paste this into a new shell
    _echo(f"cd $(python -m bentoml get {bentomodel.tag!s} -o path)", fg="white")
    # List the bento contents, using the platform-appropriate listing command.
    if psutil.WINDOWS:
        subprocess.check_output([shutil.which("dir") or "dir"], cwd=bentomodel.path)
    else:
        subprocess.check_output([shutil.which("ls") or "ls", "-R"], cwd=bentomodel.path)
    ctx.exit(0)
@overload
def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputLiteral, machine: t.Literal[True] = True) -> str: ...

View File

@@ -37,41 +37,27 @@ else:
if t.TYPE_CHECKING:
    import transformers
    from openllm._types import LiteralRuntime
    # Static-analysis-only facade describing the endpoints a running OpenLLM
    # server exposes through bentoml.client.Client. Never instantiated at
    # runtime. (The diff rendering duplicated the old multi-line stubs next to
    # the new one-liners and kept two ``else`` branches; deduplicated here to
    # the post-commit form, which adds ``embeddings_v1``.)
    class AnnotatedClient(bentoml.client.Client):
        def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
        async def async_health(self) -> t.Any: ...
        def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ...
        def metadata_v1(self) -> dict[str, t.Any]: ...
        def embeddings_v1(self) -> t.Sequence[float]: ...
else:
    transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
def in_async_context() -> bool:
    """Return True when called from inside a running asyncio event loop.

    The diff rendering kept both the old two-line and the new one-line
    ``except`` clause, producing an unreachable duplicate handler; only one
    remains here.
    """
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        # get_running_loop raises RuntimeError when no loop is running.
        return False
T = t.TypeVar("T")
class ClientMeta(t.Generic[T]):
_api_version: str
_client_class: type[bentoml.client.Client]
_host: str
_port: str
@@ -82,84 +68,50 @@ class ClientMeta(t.Generic[T]):
def __init__(self, address: str, timeout: int = 30):
self._address = address
self._timeout = timeout
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
"""Initialise subclass for HTTP and gRPC client type."""
cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
cls._api_version = api_version
@property
def _hf_agent(self) -> transformers.HfAgent:
if not self.supports_hf_agent:
raise openllm.exceptions.OpenLLMException(
f"{self.model_name} ({self.framework}) does not support running HF agent."
)
if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
if self.__agent__ is None:
if not openllm.utils.is_transformers_supports_agent():
raise RuntimeError(
"Current 'transformers' does not support Agent."
" Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'"
)
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
return self.__agent__
@property
def _metadata(self) -> T:
if in_async_context():
return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
return self.call("metadata")
@property
@abstractmethod
def model_name(self) -> str:
raise NotImplementedError
def model_name(self) -> str: raise NotImplementedError
@property
@abstractmethod
def framework(self) -> LiteralRuntime:
raise NotImplementedError
def framework(self) -> LiteralRuntime: raise NotImplementedError
@property
@abstractmethod
def timeout(self) -> int:
raise NotImplementedError
def timeout(self) -> int: raise NotImplementedError
@property
@abstractmethod
def model_id(self) -> str:
raise NotImplementedError
def model_id(self) -> str: raise NotImplementedError
@property
@abstractmethod
def configuration(self) -> dict[str, t.Any]:
raise NotImplementedError
def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
@property
@abstractmethod
def supports_embeddings(self) -> bool:
raise NotImplementedError
def supports_embeddings(self) -> bool: raise NotImplementedError
@property
@abstractmethod
def supports_hf_agent(self) -> bool:
raise NotImplementedError
def supports_hf_agent(self) -> bool: raise NotImplementedError
@property
def llm(self) -> openllm.LLM[t.Any, t.Any]:
if self.__llm__ is None:
self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
return self.__llm__
@property
def config(self) -> openllm.LLMConfig:
return self.llm.config
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
def config(self) -> openllm.LLMConfig: return self.llm.config
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
@property
def _cached(self) -> AnnotatedClient:
if self.__client__ is None:
@@ -170,73 +122,42 @@ class ClientMeta(t.Generic[T]):
def prepare(self, prompt: str, **attrs: t.Any):
return_raw_response = attrs.pop("return_raw_response", False)
return return_raw_response, *self.llm.sanitize_parameters(prompt, **attrs)
@abstractmethod
def postprocess(self, result: t.Any) -> openllm.GenerationOutput:
...
def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
@abstractmethod
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
...
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
class BaseClient(ClientMeta[T]):
def health(self) -> t.Any:
raise NotImplementedError
def health(self) -> t.Any: raise NotImplementedError
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
@overload
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
...
def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ...
@overload
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
...
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ...
@overload
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput:
...
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput: ...
def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str:
# NOTE: We set use_default_prompt_template to False for now.
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
return_attrs = attrs.pop("return_attrs", False)
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
)
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
if in_async_context():
result = httpx.post(
urljoin(self._address, f"/{self._api_version}/generate"),
json=inputs.model_dump(),
timeout=self.timeout,
).json()
else:
result = self.call("generate", inputs.model_dump())
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json()
else: result = self.call("generate", inputs.model_dump())
r = self.postprocess(result)
if return_attrs:
return r
if return_raw_response:
return openllm.utils.bentoml_cattr.unstructure(r)
if return_attrs: return r
if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r)
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return self.query(prompt, **attrs)
def ask_agent(
self,
task: str,
*,
return_code: bool = False,
remote: bool = False,
agent_type: t.LiteralString = "hf",
**attrs: t.Any,
) -> t.Any:
if agent_type == "hf":
return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else:
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if len(args) > 1:
raise ValueError("'args' should only take one positional argument.")
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
task = kwargs.pop("task", args[0])
return_code = kwargs.pop("return_code", False)
remote = kwargs.pop("remote", False)
@@ -245,40 +166,20 @@ class BaseClient(ClientMeta[T]):
except Exception as err:
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address)
# NOTE: Scikit interface
def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
return self.query(prompt, **attrs)
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError
class BaseAsyncClient(ClientMeta[T]):
async def health(self) -> t.Any:
raise NotImplementedError
async def health(self) -> t.Any: raise NotImplementedError
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
async def embed(self, prompt: list[str] | str) -> t.Sequence[float]: raise NotImplementedError
@overload
async def query(
self,
prompt: str,
*,
return_attrs: t.Literal[True] = True,
return_raw_response: bool | None = ...,
**attrs: t.Any,
) -> openllm.GenerationOutput:
...
async def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, return_raw_response: bool | None = ..., **attrs: t.Any) -> openllm.GenerationOutput: ...
@overload
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
...
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ...
@overload
async def query(
self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any
) -> dict[str, t.Any]:
...
async def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ...
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput:
# NOTE: We set use_default_prompt_template to False for now.
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
@@ -290,34 +191,18 @@ class BaseAsyncClient(ClientMeta[T]):
res = await self.acall("generate", inputs.model_dump())
r = self.postprocess(res)
if return_attrs:
return r
if return_raw_response:
return openllm.utils.bentoml_cattr.unstructure(r)
if return_attrs: return r
if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r)
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
async def ask_agent(
self,
task: str,
*,
return_code: bool = False,
remote: bool = False,
agent_type: t.LiteralString = "hf",
**attrs: t.Any,
) -> t.Any:
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any:
"""Async version of agent.run."""
if agent_type == "hf":
return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else:
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not openllm.utils.is_transformers_supports_agent():
raise RuntimeError(
"This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0"
)
if len(args) > 1:
raise ValueError("'args' should only take one positional argument.")
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
task = kwargs.pop("task", args[0])
return_code = kwargs.pop("return_code", False)
remote = kwargs.pop("remote", False)
@@ -366,8 +251,4 @@ class BaseAsyncClient(ClientMeta[T]):
return f"{tool_code}\n{code}"
# NOTE: Scikit interface
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
return await self.query(prompt, **attrs)
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return await self.query(prompt, **attrs)

View File

@@ -15,14 +15,17 @@
from __future__ import annotations
import logging
import typing as t
from urllib.parse import urljoin
from urllib.parse import urlparse
import httpx
import orjson
import openllm
from .base import BaseAsyncClient
from .base import BaseClient
from .base import in_async_context
if t.TYPE_CHECKING:
@@ -31,66 +34,41 @@ if t.TYPE_CHECKING:
else:
DictStrAny = dict
logger = logging.getLogger(__name__)
class HTTPClientMixin:
    """Shared metadata accessors for the HTTP-based OpenLLM clients.

    Each property reads one field from the server metadata payload (exposed by
    the owning client as ``self._metadata``) and converts a missing key into a
    ``RuntimeError``, since a well-formed OpenLLM server always provides these
    fields. (The diff rendering kept both the old multi-line and the new
    one-line body of every property; deduplicated here to single bodies.)
    """
    if t.TYPE_CHECKING:
        @property
        def _metadata(self) -> DictStrAny: ...
    @property
    def model_name(self) -> str:
        try:
            return self._metadata["model_name"]
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def model_id(self) -> str:
        # NOTE(review): this reads "model_name", not "model_id". Both the pre-
        # and post-commit versions agree, so it is preserved — but it looks
        # like a copy-paste slip; confirm against the server metadata schema.
        try:
            return self._metadata["model_name"]
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def framework(self) -> LiteralRuntime:
        try:
            return self._metadata["framework"]
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def timeout(self) -> int:
        try:
            return self._metadata["timeout"]
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def configuration(self) -> dict[str, t.Any]:
        # The configuration field is stored as a JSON string; decode it.
        try:
            return orjson.loads(self._metadata["configuration"])
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def supports_embeddings(self) -> bool:
        # ``dict.get`` never raises KeyError, so this handler is defensive
        # only; missing key simply yields the False default.
        try:
            return self._metadata.get("supports_embeddings", False)
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    @property
    def supports_hf_agent(self) -> bool:
        try:
            return self._metadata.get("supports_hf_agent", False)
        except KeyError:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
    def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
        # Rehydrate the raw JSON response into the structured output type.
        return openllm.GenerationOutput(**result)
@@ -100,16 +78,24 @@ class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")
super().__init__(address, timeout)
def health(self) -> t.Any:
    """Run a synchronous health check against the cached Bento client.

    (The diff rendering kept both the old two-line and new one-line body;
    deduplicated to a single body here.)
    """
    return self._cached.health()
def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput:
    """Request embeddings for ``prompt`` from the server's /v1/embeddings endpoint.

    Accepts a single string or a list of strings; a single string is wrapped
    into a one-element list before sending.

    Raises:
        ValueError: if the served model does not support embeddings.
    """
    if not self.supports_embeddings:
        raise ValueError("This model does not support embeddings.")
    if isinstance(prompt, str):
        prompt = [prompt]
    if in_async_context():
        # BUG FIX: httpx.post returns a Response object, not a dict — the
        # original unpacked the Response directly into EmbeddingsOutput.
        # Decode the JSON body first (matching the metadata path above).
        result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=orjson.dumps(prompt).decode(), timeout=self.timeout).json()
    else:
        result = self.call("embeddings", orjson.dumps(prompt).decode())
    return openllm.EmbeddingsOutput(**result)
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
def __init__(self, address: str, timeout: int = 30):
    # Default to http:// when the caller passes a bare host:port address.
    address = address if "://" in address else "http://" + address
    # NOTE(review): assumes the address always carries an explicit port —
    # ``urlparse(...).netloc.split(":")`` raises ValueError otherwise; confirm
    # callers never pass a portless URL.
    self._host, self._port = urlparse(address).netloc.split(":")
    super().__init__(address, timeout)
async def health(self) -> t.Any:
    """Run an asynchronous health check against the cached Bento client.

    (The diff rendering kept both the old two-line and new one-line body;
    deduplicated to a single body here.)
    """
    return await self._cached.async_health()
async def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput:
    """Asynchronously request embeddings for ``prompt`` from the server.

    A single string is wrapped into a one-element list before sending.

    Raises:
        ValueError: if the served model does not support embeddings.
    """
    if not self.supports_embeddings:
        raise ValueError("This model does not support embeddings.")
    texts = [prompt] if isinstance(prompt, str) else prompt
    response = await self.acall("embeddings", orjson.dumps(texts).decode())
    return openllm.EmbeddingsOutput(**response)