diff --git a/README.md b/README.md index d9f31f21..8507759a 100644 --- a/README.md +++ b/README.md @@ -474,6 +474,43 @@ LLMs into the ecosystem. Check out [Adding a New Model Guide](https://github.com/bentoml/OpenLLM/blob/main/ADDING_NEW_MODEL.md) to see how you can do it yourself. +### Embeddings + +OpenLLM tentatively provides an embeddings endpoint for supported models. +This can be accessed via `/v1/embeddings`. + +To use via CLI, simply call ``openllm embed``: + +```bash +openllm embed --endpoint http://localhost:3000 "I like to eat apples" -o json +{ + "embeddings": [ + 0.006569798570126295, + -0.031249752268195152, + -0.008072729222476482, + 0.00847396720200777, + -0.005293501541018486, + ...... + -0.002078012563288212, + -0.00676426338031888, + -0.002022686880081892 + ], + "num_tokens": 9 +} +``` + +To invoke this endpoint, use ``client.embed`` from the Python SDK: + +```python +import openllm + +client = openllm.client.HTTPClient("http://localhost:3000") + +client.embed("I like to eat apples") +``` + +> **Note**: Currently, only LLaMA models and variants support embeddings. 
+ ## ⚙️ Integrations OpenLLM is not just a standalone product; it's a building block designed to diff --git a/changelog.d/146.feature.md b/changelog.d/146.feature.md new file mode 100644 index 00000000..d0d2f67c --- /dev/null +++ b/changelog.d/146.feature.md @@ -0,0 +1,18 @@ +Users can now call ``client.embed`` to get the embeddings from the running LLMServer + + ```python + client = openllm.client.HTTPClient("http://localhost:3000") + + client.embed("Hello World") + client.embed(["Hello", "World"]) + ``` + +> **Note:** The ``client.embed`` is currently only implemented for ``openllm.client.HTTPClient`` and ``openllm.client.AsyncHTTPClient`` + +Users can also query embeddings directly from the CLI, via ``openllm embed``: + + ```bash + $ openllm embed --endpoint localhost:3000 "Hello World" "My name is Susan" + + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + ``` diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 53b3d98f..0acf36d7 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -48,7 +48,7 @@ if runner.supports_embeddings: @svc.api(input=bentoml.io.JSON.from_sample(sample=["Hey Jude, welcome to the jumgle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample(sample={"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings") async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput: responses = await runner.embeddings.async_run(phrases) - 
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"].tolist()[0], num_tokens=responses["num_tokens"]) + return openllm.EmbeddingsOutput(embeddings=t.cast(t.List[t.List[float]], responses["embeddings"].tolist())[0], num_tokens=responses["num_tokens"]) if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent(): @attr.define class HfAgentInput: diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 1e103afb..189c5670 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -30,6 +30,7 @@ from bentoml._internal.resource import get_resource from bentoml._internal.resource import system_resources from bentoml._internal.runner.strategy import THREAD_ENVS +from .utils import DEBUG from .utils import LazyType from .utils import ReprMixin @@ -115,7 +116,7 @@ def _from_system(cls: type[DynResource]) -> list[str]: if visible_devices is None: if cls.resource_id == "amd.com/gpu": if not psutil.LINUX: - warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) + if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) return [] # ROCm does not currently have the rocm_smi wheel. 
diff --git a/src/openllm/cli.py b/src/openllm/cli.py index cc43fcc0..a06e0132 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -43,6 +43,7 @@ import logging import os import pkgutil import re +import shutil import subprocess import sys import tempfile @@ -58,6 +59,7 @@ import fs.copy import fs.errors import inflection import orjson +import psutil import yaml from bentoml_cli.utils import BentoMLCommandGroup from bentoml_cli.utils import opt_callback @@ -77,6 +79,7 @@ from .utils import LazyLoader from .utils import LazyType from .utils import analytics from .utils import available_devices +from .utils import bentoml_cattr from .utils import codegen from .utils import configure_logging from .utils import dantic @@ -1571,6 +1574,38 @@ def instruct(endpoint: str, timeout: int, agent: t.LiteralString, output: Output return result else: raise click.BadOptionUsage("agent", f"Unknown agent type {agent}") +@overload +def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[True] = True) -> openllm.EmbeddingsOutput: ... +@overload +def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, output: OutputLiteral, machine: t.Literal[False] = False) -> None: ... +@cli.command() +@shared_client_options +@click.option("--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True) +@click.argument("text", type=click.STRING, nargs=-1) +@machine_option(click) +@click.pass_context +def embed(ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: OutputLiteral, machine: bool) -> openllm.EmbeddingsOutput | None: + """Get embeddings interactively, from a terminal. + + \b + ```bash + $ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?" 
+ ``` + """ + client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout) + try: + gen_embed = client.embed(list(text)) + except ValueError: + raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None + if machine: return gen_embed + elif output == "pretty": + _echo("Generated embeddings:", fg="magenta") + _echo(gen_embed.embeddings, fg="white") + _echo("\nNumber of tokens:", fg="magenta") + _echo(gen_embed.num_tokens, fg="white") + elif output == "json": _echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg="white") + else: _echo(gen_embed.embeddings, fg="white") + ctx.exit(0) @cli.command() @shared_client_options @@ -1616,17 +1651,25 @@ def list_bentos(ctx: click.Context): _echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white") ctx.exit(0) +@overload +def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[True] = True, _bento_store: BentoStore = ...) -> str: ... +@overload +def dive_bentos(ctx: click.Context, bento: str, machine: t.Literal[False] = False, _bento_store: BentoStore = ...) -> None: ... @utils_command.command() @click.argument("bento", type=str) +@machine_option(click) @click.pass_context @inject -def dive_bentos(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None: +def dive_bentos(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None: """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" try: bentomodel = _bento_store.get(bento) except bentoml.exceptions.NotFound: ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build first`") if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle": ctx.fail(f"Bento is either too old or not built with OpenLLM. 
Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.") + if machine: return bentomodel.path # copy and paste this into a new shell - _echo(f"cd $(python -m bentoml get {bentomodel.tag!s} -o path)", fg="white") + if psutil.WINDOWS: subprocess.check_output([shutil.which("dir") or "dir"], cwd=bentomodel.path) + else:subprocess.check_output([shutil.which("ls") or "ls", "-R"], cwd=bentomodel.path) + ctx.exit(0) @overload def get_prompt(model_name: str, prompt: str, format: str | None, output: OutputLiteral, machine: t.Literal[True] = True) -> str: ... diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index ecad2c87..7e5d78ea 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -37,41 +37,27 @@ else: if t.TYPE_CHECKING: import transformers from openllm._types import LiteralRuntime - class AnnotatedClient(bentoml.client.Client): - def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: - ... - - async def async_health(self) -> t.Any: - ... - - def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: - ... - - def metadata_v1(self) -> dict[str, t.Any]: - ... - -else: - transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") + def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ... + async def async_health(self) -> t.Any: ... + def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ... + def metadata_v1(self) -> dict[str, t.Any]: ... + def embeddings_v1(self) -> t.Sequence[float]: ... 
+else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") logger = logging.getLogger(__name__) - def in_async_context() -> bool: try: _ = asyncio.get_running_loop() return True - except RuntimeError: - return False - + except RuntimeError: return False T = t.TypeVar("T") - class ClientMeta(t.Generic[T]): _api_version: str _client_class: type[bentoml.client.Client] - _host: str _port: str @@ -82,84 +68,50 @@ class ClientMeta(t.Generic[T]): def __init__(self, address: str, timeout: int = 30): self._address = address self._timeout = timeout - def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): """Initialise subclass for HTTP and gRPC client type.""" cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient cls._api_version = api_version - @property def _hf_agent(self) -> transformers.HfAgent: - if not self.supports_hf_agent: - raise openllm.exceptions.OpenLLMException( - f"{self.model_name} ({self.framework}) does not support running HF agent." - ) + if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.") if self.__agent__ is None: - if not openllm.utils.is_transformers_supports_agent(): - raise RuntimeError( - "Current 'transformers' does not support Agent." - " Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'" - ) + if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. 
Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'") self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent")) return self.__agent__ - @property def _metadata(self) -> T: - if in_async_context(): - return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() + if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() return self.call("metadata") - @property @abstractmethod - def model_name(self) -> str: - raise NotImplementedError - + def model_name(self) -> str: raise NotImplementedError @property @abstractmethod - def framework(self) -> LiteralRuntime: - raise NotImplementedError - + def framework(self) -> LiteralRuntime: raise NotImplementedError @property @abstractmethod - def timeout(self) -> int: - raise NotImplementedError - + def timeout(self) -> int: raise NotImplementedError @property @abstractmethod - def model_id(self) -> str: - raise NotImplementedError - + def model_id(self) -> str: raise NotImplementedError @property @abstractmethod - def configuration(self) -> dict[str, t.Any]: - raise NotImplementedError - + def configuration(self) -> dict[str, t.Any]: raise NotImplementedError @property @abstractmethod - def supports_embeddings(self) -> bool: - raise NotImplementedError - + def supports_embeddings(self) -> bool: raise NotImplementedError @property @abstractmethod - def supports_hf_agent(self) -> bool: - raise NotImplementedError - + def supports_hf_agent(self) -> bool: raise NotImplementedError @property def llm(self) -> openllm.LLM[t.Any, t.Any]: - if self.__llm__ is None: - self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) + if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) return self.__llm__ - @property - def config(self) -> openllm.LLMConfig: - return self.llm.config - - def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: - 
return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) - - async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: - return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) - + def config(self) -> openllm.LLMConfig: return self.llm.config + def call(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) + async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> t.Any: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) @property def _cached(self) -> AnnotatedClient: if self.__client__ is None: @@ -170,73 +122,42 @@ class ClientMeta(t.Generic[T]): def prepare(self, prompt: str, **attrs: t.Any): return_raw_response = attrs.pop("return_raw_response", False) return return_raw_response, *self.llm.sanitize_parameters(prompt, **attrs) - @abstractmethod - def postprocess(self, result: t.Any) -> openllm.GenerationOutput: - ... - + def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ... @abstractmethod - def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - ... - + def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ... class BaseClient(ClientMeta[T]): - def health(self) -> t.Any: - raise NotImplementedError + def health(self) -> t.Any: raise NotImplementedError + def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError + def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError @overload - def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: - ... - + def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ... @overload - def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: - ... 
- + def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ... @overload - def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput: - ... - + def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput: ... def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str: # NOTE: We set use_default_prompt_template to False for now. use_default_prompt_template = attrs.pop("use_default_prompt_template", False) return_attrs = attrs.pop("return_attrs", False) - return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare( - prompt, use_default_prompt_template=use_default_prompt_template, **attrs - ) + return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, use_default_prompt_template=use_default_prompt_template, **attrs) inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) - if in_async_context(): - result = httpx.post( - urljoin(self._address, f"/{self._api_version}/generate"), - json=inputs.model_dump(), - timeout=self.timeout, - ).json() - else: - result = self.call("generate", inputs.model_dump()) + if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json() + else: result = self.call("generate", inputs.model_dump()) r = self.postprocess(result) - if return_attrs: - return r - if return_raw_response: - return openllm.utils.bentoml_cattr.unstructure(r) + if return_attrs: return r + if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r) return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) + def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return self.query(prompt, **attrs) - def ask_agent( - self, - task: str, - 
*, - return_code: bool = False, - remote: bool = False, - agent_type: t.LiteralString = "hf", - **attrs: t.Any, - ) -> t.Any: - if agent_type == "hf": - return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) - else: - raise RuntimeError(f"Unknown 'agent_type={agent_type}'") - + def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: + if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) + else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - if len(args) > 1: - raise ValueError("'args' should only take one positional argument.") + if len(args) > 1: raise ValueError("'args' should only take one positional argument.") task = kwargs.pop("task", args[0]) return_code = kwargs.pop("return_code", False) remote = kwargs.pop("remote", False) @@ -245,40 +166,20 @@ class BaseClient(ClientMeta[T]): except Exception as err: logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err) logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address) - # NOTE: Scikit interface - def predict(self, prompt: str, **attrs: t.Any) -> t.Any: - return self.query(prompt, **attrs) - - def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: - raise NotImplementedError class BaseAsyncClient(ClientMeta[T]): - async def health(self) -> t.Any: - raise NotImplementedError + async def health(self) -> t.Any: raise NotImplementedError + async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError + async def embed(self, prompt: list[str] | str) -> t.Sequence[float]: raise NotImplementedError @overload - async def query( - self, - prompt: str, - *, - return_attrs: t.Literal[True] = True, - return_raw_response: bool | None = ..., - **attrs: 
t.Any, - ) -> openllm.GenerationOutput: - ... - + async def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, return_raw_response: bool | None = ..., **attrs: t.Any) -> openllm.GenerationOutput: ... @overload - async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: - ... - + async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ... @overload - async def query( - self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any - ) -> dict[str, t.Any]: - ... - + async def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ... async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput: # NOTE: We set use_default_prompt_template to False for now. use_default_prompt_template = attrs.pop("use_default_prompt_template", False) @@ -290,34 +191,18 @@ class BaseAsyncClient(ClientMeta[T]): res = await self.acall("generate", inputs.model_dump()) r = self.postprocess(res) - if return_attrs: - return r - if return_raw_response: - return openllm.utils.bentoml_cattr.unstructure(r) + if return_attrs: return r + if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r) return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) - async def ask_agent( - self, - task: str, - *, - return_code: bool = False, - remote: bool = False, - agent_type: t.LiteralString = "hf", - **attrs: t.Any, - ) -> t.Any: + async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: """Async version of agent.run.""" - if agent_type == "hf": - return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) - else: - raise RuntimeError(f"Unknown 'agent_type={agent_type}'") + if agent_type == "hf": return await 
self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) + else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - if not openllm.utils.is_transformers_supports_agent(): - raise RuntimeError( - "This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0" - ) - if len(args) > 1: - raise ValueError("'args' should only take one positional argument.") + if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0") + if len(args) > 1: raise ValueError("'args' should only take one positional argument.") task = kwargs.pop("task", args[0]) return_code = kwargs.pop("return_code", False) remote = kwargs.pop("remote", False) @@ -366,8 +251,4 @@ class BaseAsyncClient(ClientMeta[T]): return f"{tool_code}\n{code}" # NOTE: Scikit interface - async def predict(self, prompt: str, **attrs: t.Any) -> t.Any: - return await self.query(prompt, **attrs) - - def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: - raise NotImplementedError + async def predict(self, prompt: str, **attrs: t.Any) -> t.Any: return await self.query(prompt, **attrs) diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index ade5163f..c7f0a9e1 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -15,14 +15,17 @@ from __future__ import annotations import logging import typing as t +from urllib.parse import urljoin from urllib.parse import urlparse +import httpx import orjson import openllm from .base import BaseAsyncClient from .base import BaseClient +from .base import in_async_context if t.TYPE_CHECKING: @@ -31,66 +34,41 @@ if t.TYPE_CHECKING: else: DictStrAny = dict - logger = logging.getLogger(__name__) - class HTTPClientMixin: if t.TYPE_CHECKING: - @property - def 
_metadata(self) -> DictStrAny: - ... + def _metadata(self) -> DictStrAny: ... @property def model_name(self) -> str: - try: - return self._metadata["model_name"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def model_id(self) -> str: - try: - return self._metadata["model_name"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def framework(self) -> LiteralRuntime: - try: - return self._metadata["framework"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["framework"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def timeout(self) -> int: - try: - return self._metadata["timeout"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["timeout"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def configuration(self) -> dict[str, t.Any]: - try: - return orjson.loads(self._metadata["configuration"]) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return orjson.loads(self._metadata["configuration"]) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_embeddings(self) -> bool: - try: - return self._metadata.get("supports_embeddings", False) - except KeyError: - raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None - + try: return self._metadata.get("supports_embeddings", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_hf_agent(self) -> bool: - try: - return self._metadata.get("supports_hf_agent", False) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.get("supports_hf_agent", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result) @@ -100,16 +78,24 @@ class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":") super().__init__(address, timeout) - - def health(self) -> t.Any: - return self._cached.health() - + def health(self) -> t.Any: return self._cached.health() + def embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput: + if not self.supports_embeddings: + raise ValueError("This model does not support embeddings.") + if isinstance(prompt, str): prompt = [prompt] + if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=orjson.dumps(prompt).decode(), timeout=self.timeout) + else: result = self.call("embeddings", orjson.dumps(prompt).decode()) + return openllm.EmbeddingsOutput(**result) class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]): def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":") super().__init__(address, timeout) - - async def health(self) -> t.Any: - return await self._cached.async_health() + async def health(self) -> t.Any: return await self._cached.async_health() + async def 
embed(self, prompt: list[str] | str) -> openllm.EmbeddingsOutput: + if not self.supports_embeddings: + raise ValueError("This model does not support embeddings.") + if isinstance(prompt, str): prompt = [prompt] + res = await self.acall("embeddings", orjson.dumps(prompt).decode()) + return openllm.EmbeddingsOutput(**res)