mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-18 23:17:26 -04:00
feat: openllm.client
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -30,14 +30,14 @@ from .exceptions import MissingDependencyError
|
||||
|
||||
_import_structure = {
|
||||
"_llm": ["LLM", "Runner"],
|
||||
"cli": ["start", "start_grpc"],
|
||||
"_configuration": ["LLMConfig"],
|
||||
"_package": ["build"],
|
||||
"exceptions": [],
|
||||
"_schema": ["PromptTemplate", "GenerationInput", "GenerationOutput"],
|
||||
"_schema": ["GenerationInput", "GenerationOutput"],
|
||||
"utils": [],
|
||||
"models": [],
|
||||
"client": [],
|
||||
"cli": ["start", "start_grpc"],
|
||||
# NOTE: models
|
||||
"models.auto": ["AutoConfig", "CONFIG_MAPPING"],
|
||||
"models.flan_t5": ["FlanT5Config"],
|
||||
|
||||
@@ -109,6 +109,16 @@ def field_to_options(
|
||||
)
|
||||
|
||||
|
||||
def generate_kwargs_from_envvar(model: GenerationConfig | LLMConfig) -> dict[str, str]:
|
||||
kwargs: dict[str, t.Any] = {}
|
||||
for key, field in model.model_fields.items():
|
||||
if field.json_schema_extra is not None:
|
||||
if "env" not in field.json_schema_extra:
|
||||
raise RuntimeError(f"Invalid {model} passed. Only accept LLMConfig or LLMConfig.generation_config")
|
||||
kwargs[key] = os.environ.get(field.json_schema_extra["env"], field.default)
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
|
||||
class GenerationConfig(pydantic.BaseModel):
|
||||
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
|
||||
with some additional validation and environment constructor.
|
||||
@@ -549,9 +559,13 @@ class LLMConfig(pydantic.BaseModel, ABC):
|
||||
"""
|
||||
from_env_ = self.from_env()
|
||||
# filtered out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
kwargs = {**generate_kwargs_from_envvar(self), **{k: v for k, v in kwargs.items() if v is not None}}
|
||||
if __llm_config__ is not None:
|
||||
kwargs = {**__llm_config__.model_dump(), **kwargs}
|
||||
kwargs["generation_config"] = {
|
||||
**generate_kwargs_from_envvar(self.generation_config),
|
||||
**kwargs.get("generation_config", {}),
|
||||
}
|
||||
if from_env_:
|
||||
return from_env_.model_construct(**kwargs)
|
||||
return self.model_construct(**kwargs)
|
||||
|
||||
@@ -342,7 +342,7 @@ def start_model_command(
|
||||
if _bentoml_config_options
|
||||
else ""
|
||||
+ f"api_server.timeout={server_timeout}"
|
||||
+ f' runners."llm-{llconfig.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
|
||||
+ f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
|
||||
)
|
||||
|
||||
start_env.update(
|
||||
@@ -355,6 +355,13 @@ def start_model_command(
|
||||
}
|
||||
)
|
||||
|
||||
if envvar == "flax":
|
||||
llm = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained)
|
||||
elif envvar == "tf":
|
||||
llm = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained)
|
||||
else:
|
||||
llm = openllm.AutoLLM.for_model(model_name, pretrained=pretrained)
|
||||
|
||||
if llm.requirements is not None:
|
||||
click.secho(
|
||||
f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow"
|
||||
|
||||
@@ -20,5 +20,16 @@ To start interact with the server, you can do the following:
|
||||
>>> client = openllm.client.create("http://0.0.0.0:3000")
|
||||
>>> client.query("What is the meaning of life?")
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm_client import for_model as for_model
|
||||
import typing as t
|
||||
|
||||
import openllm_client as _
|
||||
|
||||
|
||||
def __dir__():
|
||||
return dir(_)
|
||||
|
||||
|
||||
def __getattr__(value: t.Any) -> t.Any:
|
||||
return getattr(_, value)
|
||||
|
||||
@@ -53,3 +53,5 @@ Currently, ChatGLM only supports PyTorch. Make sure ``torch`` is available in yo
|
||||
ChatGLM Runner will use THUDM/ChatGLM-6b as the default model. To change any to any other ChatGLM
|
||||
saved pretrained, or a fine-tune ChatGLM, provide ``OPENLLM_CHATGLM_PRETRAINED='thudm/chatglm-6b-int8'``
|
||||
"""
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
@@ -37,9 +37,9 @@ FLAN-T5 Runner will use google/flan-t5-large as the default model. To change any
|
||||
saved pretrained, or a fine-tune FLAN-T5, provide ``OPENLLM_FLAN_T5_PRETRAINED='google/flan-t5-xxl'``
|
||||
"""
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """Please use the following piece of context to answer the question at the end.
|
||||
{context}
|
||||
Question:{question}
|
||||
DEFAULT_PROMPT_TEMPLATE = """\
|
||||
Answer the following question.
|
||||
Question:{instruction}
|
||||
Answer:"""
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ EOD = "<|endoftext|>"
|
||||
FIM_INDICATOR = "<FILL_HERE>"
|
||||
|
||||
|
||||
class StarCoder(openllm.LLM, _internal=True, requires_gpu=True):
|
||||
class StarCoder(openllm.LLM, _internal=True):
|
||||
default_model = "bigcode/starcoder"
|
||||
|
||||
requirements = ["bitandbytes"]
|
||||
|
||||
@@ -23,12 +23,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from ._prompt import PromptTemplate as PromptTemplate
|
||||
from ._prompt import PromptTemplate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30):
|
||||
def create(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30):
|
||||
if kind == "http":
|
||||
from .runtimes.http import HTTPClient as _HTTPClient
|
||||
|
||||
@@ -39,3 +39,6 @@ def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: i
|
||||
return _GrpcClient(address, timeout)
|
||||
else:
|
||||
raise ValueError(f"Unknown kind: {kind}. Only 'http' and 'grpc' are supported.")
|
||||
|
||||
|
||||
__all__ = ["create", "PromptTemplate"]
|
||||
|
||||
@@ -13,15 +13,21 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import pydantic
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
DictStrStr = dict[str, str]
|
||||
else:
|
||||
DictStrStr = dict
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
|
||||
if len(args) > 0:
|
||||
raise ValueError("Positional arguments are not supported")
|
||||
@@ -44,14 +50,22 @@ class PromptFormatter(string.Formatter):
|
||||
_default_formatter = PromptFormatter()
|
||||
|
||||
|
||||
class PromptTemplate(pydantic.BaseModel):
|
||||
class PartialDict(DictStrStr):
|
||||
def __missing__(self, key: str):
|
||||
return "{" + key + "}"
|
||||
|
||||
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
input_variables: t.Sequence[str]
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
def to_str(self, **kwargs: str) -> str:
|
||||
def to_str(self, __partial_dict__: PartialDict | None = None, **kwargs: str) -> str:
|
||||
"""Generate a prompt from the template and input variables"""
|
||||
if __partial_dict__:
|
||||
return _default_formatter.vformat(self.template, (), __partial_dict__)
|
||||
if not kwargs:
|
||||
raise ValueError("Keyword arguments are required")
|
||||
if not all(k in kwargs for k in self.input_variables):
|
||||
|
||||
@@ -15,26 +15,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import bentoml
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
class BaseClient:
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
prompt_template: str | openllm.PromptTemplate | None = None,
|
||||
**llm_config: t.Any,
|
||||
) -> str:
|
||||
class AnnotatedClient(bentoml.client.Client):
|
||||
def health(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
def metadata_v1(self) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
|
||||
class BaseClient(ABC):
|
||||
_metadata: dict[str, t.Any]
|
||||
_api_version: str
|
||||
_config_class: type[bentoml.client.Client]
|
||||
|
||||
_host: str
|
||||
_port: str
|
||||
|
||||
__config__: openllm.LLMConfig | None = None
|
||||
__client__: AnnotatedClient | None = None
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._address = address
|
||||
self._timeout = timeout
|
||||
assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
|
||||
self._metadata = self.call("metadata")
|
||||
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
|
||||
cls._config_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
|
||||
cls._api_version = api_version
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._metadata["model_name"]
|
||||
|
||||
@property
|
||||
def config(self) -> openllm.LLMConfig:
|
||||
if self.__config__ is None:
|
||||
self.__config__ = openllm.AutoConfig.for_model(self.model)
|
||||
return self.__config__
|
||||
|
||||
def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
return getattr(self._cached, f"{name}_{self._api_version}")(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _cached(self) -> AnnotatedClient:
|
||||
if self.__client__ is None:
|
||||
self._config_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__client__ = t.cast("AnnotatedClient", self._config_class.from_url(self._address))
|
||||
return self.__client__
|
||||
|
||||
def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
context: str,
|
||||
prompt_template: str | openllm.PromptTemplate | None = None,
|
||||
**llm_config: t.Any,
|
||||
) -> str:
|
||||
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -18,8 +18,6 @@ import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
|
||||
import openllm
|
||||
|
||||
from .base import BaseClient
|
||||
@@ -30,51 +28,26 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrpcClient(BaseClient):
|
||||
__: bentoml.client.GrpcClient | None = None
|
||||
|
||||
class GrpcClient(BaseClient, client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._timeout = timeout
|
||||
self._address = address
|
||||
self._host, self._port = address.split(":")
|
||||
|
||||
@property
|
||||
def _cached(self) -> bentoml.client.GrpcClient:
|
||||
if self.__ is None:
|
||||
bentoml.client.GrpcClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__ = bentoml.client.GrpcClient.from_url(self._address)
|
||||
return self.__
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> health_pb2.HealthCheckResponse:
|
||||
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
|
||||
def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str:
|
||||
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
|
||||
return_raw_response = attrs.pop("return_raw_response", False)
|
||||
model_name = self._cached.model_name()
|
||||
if prompt_template is None:
|
||||
# return the default prompt
|
||||
prompt_template = openllm.PromptTemplate.from_default(model_name)
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = openllm.PromptTemplate.from_template(prompt_template)
|
||||
variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables}
|
||||
config = openllm.AutoConfig.for_model(model_name).with_options(
|
||||
**{k: v for k, v in attrs.items() if k not in variables}
|
||||
)
|
||||
r = openllm.schema.GenerationOutput(
|
||||
**self._cached.generate(
|
||||
openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict())
|
||||
r = openllm.GenerationOutput(
|
||||
**self.call(
|
||||
"generate",
|
||||
openllm.GenerationInput(
|
||||
prompt=prompt,
|
||||
llm_config=self.config.with_options(**attrs),
|
||||
),
|
||||
)
|
||||
)
|
||||
if return_raw_response:
|
||||
return r.dict()
|
||||
return r.model_dump()
|
||||
|
||||
return prompt_template.to_str(**variables) + "".join(r.responses)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
context: str,
|
||||
prompt_template: str | openllm.PromptTemplate | None = None,
|
||||
**llm_config: t.Any,
|
||||
):
|
||||
...
|
||||
return r.responses
|
||||
|
||||
@@ -18,8 +18,6 @@ import logging
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bentoml
|
||||
|
||||
import openllm
|
||||
|
||||
from .base import BaseClient
|
||||
@@ -28,51 +26,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPClient(BaseClient):
|
||||
__: bentoml.client.HTTPClient | None = None
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._timeout = timeout
|
||||
self._address = address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
|
||||
@property
|
||||
def _cached(self) -> bentoml.client.HTTPClient:
|
||||
if self.__ is None:
|
||||
bentoml.client.HTTPClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__ = bentoml.client.HTTPClient.from_url(self._address)
|
||||
return self.__
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> t.Any:
|
||||
return self._cached.health()
|
||||
|
||||
def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str:
|
||||
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
|
||||
return_raw_response = attrs.pop("return_raw_response", False)
|
||||
model_name = self._cached.model_name()
|
||||
if prompt_template is None:
|
||||
# return the default prompt
|
||||
prompt_template = openllm.PromptTemplate.from_default(model_name)
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = openllm.PromptTemplate.from_template(prompt_template)
|
||||
variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables}
|
||||
config = openllm.AutoConfig.for_model(model_name).with_options(
|
||||
**{k: v for k, v in attrs.items() if k not in variables}
|
||||
)
|
||||
r = openllm.schema.GenerationOutput(
|
||||
**self._cached.generate(
|
||||
openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict())
|
||||
r = openllm.GenerationOutput(
|
||||
**self.call(
|
||||
"generate",
|
||||
openllm.GenerationInput(
|
||||
prompt=prompt,
|
||||
llm_config=self.config.with_options(**attrs),
|
||||
),
|
||||
)
|
||||
)
|
||||
if return_raw_response:
|
||||
return r.dict()
|
||||
return r.model_dump()
|
||||
|
||||
return prompt_template.to_str(**variables) + "".join(r.responses)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
context: str,
|
||||
prompt_template: str | openllm.PromptTemplate | None = None,
|
||||
**llm_config: t.Any,
|
||||
):
|
||||
...
|
||||
return r.responses
|
||||
|
||||
Reference in New Issue
Block a user