feat: openllm.client

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-05-26 07:17:28 +00:00
parent ac933d60f1
commit 4127961c5c
12 changed files with 150 additions and 107 deletions

View File

@@ -30,14 +30,14 @@ from .exceptions import MissingDependencyError
_import_structure = {
"_llm": ["LLM", "Runner"],
"cli": ["start", "start_grpc"],
"_configuration": ["LLMConfig"],
"_package": ["build"],
"exceptions": [],
"_schema": ["PromptTemplate", "GenerationInput", "GenerationOutput"],
"_schema": ["GenerationInput", "GenerationOutput"],
"utils": [],
"models": [],
"client": [],
"cli": ["start", "start_grpc"],
# NOTE: models
"models.auto": ["AutoConfig", "CONFIG_MAPPING"],
"models.flan_t5": ["FlanT5Config"],

View File

@@ -109,6 +109,16 @@ def field_to_options(
)
def generate_kwargs_from_envvar(model: GenerationConfig | LLMConfig) -> dict[str, t.Any]:
    """Build keyword overrides for *model*'s fields from environment variables.

    Every pydantic field that declares ``json_schema_extra`` must carry an
    ``"env"`` entry naming the environment variable to read; the variable's
    current value (or the field default when it is unset) is collected.
    Entries whose value is ``None`` are dropped so they never clobber other
    configuration sources.

    Note: the return annotation is ``dict[str, t.Any]`` (not ``dict[str, str]``)
    because unset variables fall back to ``field.default``, which may be any type.

    Raises:
        RuntimeError: if a field declares ``json_schema_extra`` without an
            ``"env"`` key, i.e. *model* is not an env-aware config.
    """
    kwargs: dict[str, t.Any] = {}
    for key, field in model.model_fields.items():
        extra = field.json_schema_extra
        if extra is None:
            # Field opted out of env-var configuration entirely.
            continue
        if "env" not in extra:
            raise RuntimeError(f"Invalid {model} passed. Only accept LLMConfig or LLMConfig.generation_config")
        kwargs[key] = os.environ.get(extra["env"], field.default)
    # None means "unset" — filter it out rather than overriding downstream defaults.
    return {k: v for k, v in kwargs.items() if v is not None}
class GenerationConfig(pydantic.BaseModel):
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
with some additional validation and environment constructor.
@@ -549,9 +559,13 @@ class LLMConfig(pydantic.BaseModel, ABC):
"""
from_env_ = self.from_env()
# filtered out None values
kwargs = {k: v for k, v in kwargs.items() if v is not None}
kwargs = {**generate_kwargs_from_envvar(self), **{k: v for k, v in kwargs.items() if v is not None}}
if __llm_config__ is not None:
kwargs = {**__llm_config__.model_dump(), **kwargs}
kwargs["generation_config"] = {
**generate_kwargs_from_envvar(self.generation_config),
**kwargs.get("generation_config", {}),
}
if from_env_:
return from_env_.model_construct(**kwargs)
return self.model_construct(**kwargs)

View File

@@ -342,7 +342,7 @@ def start_model_command(
if _bentoml_config_options
else ""
+ f"api_server.timeout={server_timeout}"
+ f' runners."llm-{llconfig.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
+ f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
)
start_env.update(
@@ -355,6 +355,13 @@ def start_model_command(
}
)
if envvar == "flax":
llm = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained)
elif envvar == "tf":
llm = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained)
else:
llm = openllm.AutoLLM.for_model(model_name, pretrained=pretrained)
if llm.requirements is not None:
click.secho(
f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow"

View File

@@ -20,5 +20,16 @@ To start interact with the server, you can do the following:
>>> client = openllm.client.create("http://0.0.0.0:3000")
>>> client.query("What is the meaning of life?")
"""
from __future__ import annotations
from openllm_client import for_model as for_model
import typing as t
import openllm_client as _
def __dir__():
    # PEP 562 module-level __dir__: expose the attribute surface of the real
    # implementation package (openllm_client, imported above as `_`) so that
    # dir()/tab-completion on this shim module reflects its contents.
    return dir(_)
def __getattr__(value: str) -> t.Any:
    # PEP 562 module-level __getattr__: forward any attribute access on this
    # shim module to openllm_client (`_`). Unknown names raise AttributeError
    # from getattr itself, which is the contract Python expects here.
    return getattr(_, value)

View File

@@ -53,3 +53,5 @@ Currently, ChatGLM only supports PyTorch. Make sure ``torch`` is available in yo
ChatGLM Runner will use THUDM/ChatGLM-6b as the default model. To change to any other ChatGLM
saved pretrained model, or a fine-tuned ChatGLM, provide ``OPENLLM_CHATGLM_PRETRAINED='thudm/chatglm-6b-int8'``
"""
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""

View File

@@ -37,9 +37,9 @@ FLAN-T5 Runner will use google/flan-t5-large as the default model. To change any
saved pretrained, or a fine-tune FLAN-T5, provide ``OPENLLM_FLAN_T5_PRETRAINED='google/flan-t5-xxl'``
"""
DEFAULT_PROMPT_TEMPLATE = """Please use the following piece of context to answer the question at the end.
{context}
Question:{question}
DEFAULT_PROMPT_TEMPLATE = """\
Answer the following question.
Question:{instruction}
Answer:"""

View File

@@ -37,7 +37,7 @@ EOD = "<|endoftext|>"
FIM_INDICATOR = "<FILL_HERE>"
class StarCoder(openllm.LLM, _internal=True, requires_gpu=True):
class StarCoder(openllm.LLM, _internal=True):
default_model = "bigcode/starcoder"
requirements = ["bitandbytes"]

View File

@@ -23,12 +23,12 @@ from __future__ import annotations
import logging
import typing as t
from ._prompt import PromptTemplate as PromptTemplate
from ._prompt import PromptTemplate
logger = logging.getLogger(__name__)
def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30):
def create(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30):
if kind == "http":
from .runtimes.http import HTTPClient as _HTTPClient
@@ -39,3 +39,6 @@ def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: i
return _GrpcClient(address, timeout)
else:
raise ValueError(f"Unknown kind: {kind}. Only 'http' and 'grpc' are supported.")
__all__ = ["create", "PromptTemplate"]

View File

@@ -13,15 +13,21 @@
# limitations under the License.
from __future__ import annotations
import dataclasses
import string
import typing as t
import pydantic
import openllm
if t.TYPE_CHECKING:
    # Precise alias seen only by type checkers.
    DictStrStr = dict[str, str]
else:
    # NOTE(review): runtime falls back to bare `dict` — presumably so the
    # alias can be subclassed on interpreters where the subscripted generic
    # is unavailable; confirm against the project's minimum Python version.
    DictStrStr = dict
class PromptFormatter(string.Formatter):
"""This PromptFormatter is largely based on langchain's implementation."""
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
if len(args) > 0:
raise ValueError("Positional arguments are not supported")
@@ -44,14 +50,22 @@ class PromptFormatter(string.Formatter):
_default_formatter = PromptFormatter()
class PromptTemplate(pydantic.BaseModel):
class PartialDict(DictStrStr):
    """A dict for partial template formatting.

    Looking up a missing key yields the key wrapped back in braces
    (``"{key}"``) instead of raising ``KeyError``, so placeholders that were
    not supplied survive a formatting pass and can be filled in later.
    """

    def __missing__(self, key: str):
        # Re-emit the placeholder untouched rather than failing the lookup.
        return f"{{{key}}}"
@dataclasses.dataclass(slots=True)
class PromptTemplate:
template: str
input_variables: t.Sequence[str]
model_config = {"extra": "forbid"}
def to_str(self, **kwargs: str) -> str:
def to_str(self, __partial_dict__: PartialDict | None = None, **kwargs: str) -> str:
"""Generate a prompt from the template and input variables"""
if __partial_dict__:
return _default_formatter.vformat(self.template, (), __partial_dict__)
if not kwargs:
raise ValueError("Keyword arguments are required")
if not all(k in kwargs for k in self.input_variables):

View File

@@ -15,26 +15,72 @@
from __future__ import annotations
import typing as t
from abc import abstractmethod
from abc import ABC, abstractmethod
import bentoml
import openllm
if t.TYPE_CHECKING:
class BaseClient:
@abstractmethod
def query(
self,
prompt_template: str | openllm.PromptTemplate | None = None,
**llm_config: t.Any,
) -> str:
class AnnotatedClient(bentoml.client.Client):
def health(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
...
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
...
def metadata_v1(self) -> dict[str, t.Any]:
...
class BaseClient(ABC):
_metadata: dict[str, t.Any]
_api_version: str
_config_class: type[bentoml.client.Client]
_host: str
_port: str
__config__: openllm.LLMConfig | None = None
__client__: AnnotatedClient | None = None
def __init__(self, address: str, timeout: int = 30):
self._address = address
self._timeout = timeout
assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
self._metadata = self.call("metadata")
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
cls._config_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
cls._api_version = api_version
@property
def model(self) -> str:
return self._metadata["model_name"]
@property
def config(self) -> openllm.LLMConfig:
if self.__config__ is None:
self.__config__ = openllm.AutoConfig.for_model(self.model)
return self.__config__
def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
return getattr(self._cached, f"{name}_{self._api_version}")(*args, **kwargs)
@property
def _cached(self) -> AnnotatedClient:
if self.__client__ is None:
self._config_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
self.__client__ = t.cast("AnnotatedClient", self._config_class.from_url(self._address))
return self.__client__
def health(self) -> t.Any:
raise NotImplementedError
@abstractmethod
def chat(
self,
prompt: str,
context: str,
prompt_template: str | openllm.PromptTemplate | None = None,
**llm_config: t.Any,
) -> str:
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
raise NotImplementedError
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
raise NotImplementedError

View File

@@ -18,8 +18,6 @@ import asyncio
import logging
import typing as t
import bentoml
import openllm
from .base import BaseClient
@@ -30,51 +28,26 @@ if t.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class GrpcClient(BaseClient):
__: bentoml.client.GrpcClient | None = None
class GrpcClient(BaseClient, client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._timeout = timeout
self._address = address
self._host, self._port = address.split(":")
@property
def _cached(self) -> bentoml.client.GrpcClient:
if self.__ is None:
bentoml.client.GrpcClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
self.__ = bentoml.client.GrpcClient.from_url(self._address)
return self.__
super().__init__(address, timeout)
def health(self) -> health_pb2.HealthCheckResponse:
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str:
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
return_raw_response = attrs.pop("return_raw_response", False)
model_name = self._cached.model_name()
if prompt_template is None:
# return the default prompt
prompt_template = openllm.PromptTemplate.from_default(model_name)
elif isinstance(prompt_template, str):
prompt_template = openllm.PromptTemplate.from_template(prompt_template)
variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables}
config = openllm.AutoConfig.for_model(model_name).with_options(
**{k: v for k, v in attrs.items() if k not in variables}
)
r = openllm.schema.GenerationOutput(
**self._cached.generate(
openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict())
r = openllm.GenerationOutput(
**self.call(
"generate",
openllm.GenerationInput(
prompt=prompt,
llm_config=self.config.with_options(**attrs),
),
)
)
if return_raw_response:
return r.dict()
return r.model_dump()
return prompt_template.to_str(**variables) + "".join(r.responses)
def chat(
self,
prompt: str,
context: str,
prompt_template: str | openllm.PromptTemplate | None = None,
**llm_config: t.Any,
):
...
return r.responses

View File

@@ -18,8 +18,6 @@ import logging
import typing as t
from urllib.parse import urlparse
import bentoml
import openllm
from .base import BaseClient
@@ -28,51 +26,26 @@ logger = logging.getLogger(__name__)
class HTTPClient(BaseClient):
__: bentoml.client.HTTPClient | None = None
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._timeout = timeout
self._address = address
self._host, self._port = urlparse(address).netloc.split(":")
@property
def _cached(self) -> bentoml.client.HTTPClient:
if self.__ is None:
bentoml.client.HTTPClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
self.__ = bentoml.client.HTTPClient.from_url(self._address)
return self.__
super().__init__(address, timeout)
def health(self) -> t.Any:
return self._cached.health()
def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str:
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]:
return_raw_response = attrs.pop("return_raw_response", False)
model_name = self._cached.model_name()
if prompt_template is None:
# return the default prompt
prompt_template = openllm.PromptTemplate.from_default(model_name)
elif isinstance(prompt_template, str):
prompt_template = openllm.PromptTemplate.from_template(prompt_template)
variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables}
config = openllm.AutoConfig.for_model(model_name).with_options(
**{k: v for k, v in attrs.items() if k not in variables}
)
r = openllm.schema.GenerationOutput(
**self._cached.generate(
openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict())
r = openllm.GenerationOutput(
**self.call(
"generate",
openllm.GenerationInput(
prompt=prompt,
llm_config=self.config.with_options(**attrs),
),
)
)
if return_raw_response:
return r.dict()
return r.model_dump()
return prompt_template.to_str(**variables) + "".join(r.responses)
def chat(
self,
prompt: str,
context: str,
prompt_template: str | openllm.PromptTemplate | None = None,
**llm_config: t.Any,
):
...
return r.responses