diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 09b0cb88..ee7c7383 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -30,14 +30,14 @@ from .exceptions import MissingDependencyError _import_structure = { "_llm": ["LLM", "Runner"], - "cli": ["start", "start_grpc"], "_configuration": ["LLMConfig"], "_package": ["build"], "exceptions": [], - "_schema": ["PromptTemplate", "GenerationInput", "GenerationOutput"], + "_schema": ["GenerationInput", "GenerationOutput"], "utils": [], "models": [], "client": [], + "cli": ["start", "start_grpc"], # NOTE: models "models.auto": ["AutoConfig", "CONFIG_MAPPING"], "models.flan_t5": ["FlanT5Config"], diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 7ddb7ccb..cd71188f 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -109,6 +109,16 @@ def field_to_options( ) +def generate_kwargs_from_envvar(model: GenerationConfig | LLMConfig) -> dict[str, str]: + kwargs: dict[str, t.Any] = {} + for key, field in model.model_fields.items(): + if field.json_schema_extra is not None: + if "env" not in field.json_schema_extra: + raise RuntimeError(f"Invalid {model} passed. Only accept LLMConfig or LLMConfig.generation_config") + kwargs[key] = os.environ.get(field.json_schema_extra["env"], field.default) + return {k: v for k, v in kwargs.items() if v is not None} + + class GenerationConfig(pydantic.BaseModel): """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, with some additional validation and environment constructor. @@ -549,9 +559,13 @@ class LLMConfig(pydantic.BaseModel, ABC): """ from_env_ = self.from_env() # filtered out None values - kwargs = {k: v for k, v in kwargs.items() if v is not None} + kwargs = {**generate_kwargs_from_envvar(self), **{k: v for k, v in kwargs.items() if v is not None}} if __llm_config__ is not None: kwargs = {**__llm_config__.model_dump(), **kwargs} + kwargs["generation_config"] = { + **generate_kwargs_from_envvar(self.generation_config), + **kwargs.get("generation_config", {}), + } if from_env_: return from_env_.model_construct(**kwargs) return self.model_construct(**kwargs) diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 4b7aaf57..238e6823 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -342,7 +342,7 @@ def start_model_command( if _bentoml_config_options else "" + f"api_server.timeout={server_timeout}" - + f' runners."llm-{llconfig.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}' + + f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}' ) start_env.update( @@ -355,6 +355,13 @@ def start_model_command( } ) + if envvar == "flax": + llm = openllm.AutoFlaxLLM.for_model(model_name, pretrained=pretrained) + elif envvar == "tf": + llm = openllm.AutoTFLLM.for_model(model_name, pretrained=pretrained) + else: + llm = openllm.AutoLLM.for_model(model_name, pretrained=pretrained) + if llm.requirements is not None: click.secho( f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow" diff --git a/src/openllm/client.py b/src/openllm/client.py index d5d8c43a..a9e8e88b 100644 --- a/src/openllm/client.py +++ b/src/openllm/client.py @@ -20,5 +20,16 @@ To start interact with the server, you can do the following: >>> client = openllm.client.create("http://0.0.0.0:3000") >>> client.query("What is the meaning of life?") """ +from __future__ import annotations -from openllm_client import for_model as for_model +import typing as t + +import openllm_client as _ + + +def __dir__(): + return dir(_) + + +def __getattr__(value: t.Any) -> t.Any: + return getattr(_, value) diff --git a/src/openllm/models/chatglm/configuration_chatglm.py b/src/openllm/models/chatglm/configuration_chatglm.py index 62cbc4cb..19b78de1 100644 --- a/src/openllm/models/chatglm/configuration_chatglm.py +++ b/src/openllm/models/chatglm/configuration_chatglm.py @@ -53,3 +53,5 @@ Currently, ChatGLM only supports PyTorch. Make sure ``torch`` is available in yo ChatGLM Runner will use THUDM/ChatGLM-6b as the default model. To change any to any other ChatGLM saved pretrained, or a fine-tune ChatGLM, provide ``OPENLLM_CHATGLM_PRETRAINED='thudm/chatglm-6b-int8'`` """ + +DEFAULT_PROMPT_TEMPLATE = """{instruction}""" diff --git a/src/openllm/models/flan_t5/configuration_flan_t5.py b/src/openllm/models/flan_t5/configuration_flan_t5.py index 88d2c3a3..61446c18 100644 --- a/src/openllm/models/flan_t5/configuration_flan_t5.py +++ b/src/openllm/models/flan_t5/configuration_flan_t5.py @@ -37,9 +37,9 @@ FLAN-T5 Runner will use google/flan-t5-large as the default model. To change any saved pretrained, or a fine-tune FLAN-T5, provide ``OPENLLM_FLAN_T5_PRETRAINED='google/flan-t5-xxl'`` """ -DEFAULT_PROMPT_TEMPLATE = """Please use the following piece of context to answer the question at the end. -{context} -Question:{question} +DEFAULT_PROMPT_TEMPLATE = """\ +Answer the following question. +Question:{instruction} Answer:""" diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index c3e596fa..2145e1b6 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -37,7 +37,7 @@ EOD = "<|endoftext|>" FIM_INDICATOR = "" -class StarCoder(openllm.LLM, _internal=True, requires_gpu=True): +class StarCoder(openllm.LLM, _internal=True): default_model = "bigcode/starcoder" requirements = ["bitandbytes"] diff --git a/src/openllm_client/__init__.py b/src/openllm_client/__init__.py index 959dc5ed..318a78eb 100644 --- a/src/openllm_client/__init__.py +++ b/src/openllm_client/__init__.py @@ -23,12 +23,12 @@ from __future__ import annotations import logging import typing as t -from ._prompt import PromptTemplate as PromptTemplate +from ._prompt import PromptTemplate logger = logging.getLogger(__name__) -def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30): +def create(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: int = 30): if kind == "http": from .runtimes.http import HTTPClient as _HTTPClient @@ -39,3 +39,6 @@ def for_model(address: str, kind: t.Literal["http", "grpc"] = "http", timeout: i return _GrpcClient(address, timeout) else: raise ValueError(f"Unknown kind: {kind}. Only 'http' and 'grpc' are supported.") + + +__all__ = ["create", "PromptTemplate"] diff --git a/src/openllm_client/_prompt.py b/src/openllm_client/_prompt.py index 828893fa..3d703c00 100644 --- a/src/openllm_client/_prompt.py +++ b/src/openllm_client/_prompt.py @@ -13,15 +13,21 @@ # limitations under the License. from __future__ import annotations +import dataclasses import string import typing as t -import pydantic - import openllm +if t.TYPE_CHECKING: + DictStrStr = dict[str, str] +else: + DictStrStr = dict + class PromptFormatter(string.Formatter): + """This PromptFormatter is largely based on langchain's implementation.""" + def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str: if len(args) > 0: raise ValueError("Positional arguments are not supported") @@ -44,14 +50,22 @@ class PromptFormatter(string.Formatter): _default_formatter = PromptFormatter() -class PromptTemplate(pydantic.BaseModel): +class PartialDict(DictStrStr): + def __missing__(self, key: str): + return "{" + key + "}" + + +@dataclasses.dataclass(slots=True) +class PromptTemplate: template: str input_variables: t.Sequence[str] model_config = {"extra": "forbid"} - def to_str(self, **kwargs: str) -> str: + def to_str(self, __partial_dict__: PartialDict | None = None, **kwargs: str) -> str: """Generate a prompt from the template and input variables""" + if __partial_dict__: + return _default_formatter.vformat(self.template, (), __partial_dict__) if not kwargs: raise ValueError("Keyword arguments are required") if not all(k in kwargs for k in self.input_variables): diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index 1c0ae4c1..76337476 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -15,26 +15,72 @@ from __future__ import annotations import typing as t -from abc import abstractmethod +from abc import ABC, abstractmethod + +import bentoml import openllm +if t.TYPE_CHECKING: -class BaseClient: - @abstractmethod - def query( - self, - prompt_template: str | openllm.PromptTemplate | None = None, - **llm_config: t.Any, - ) -> str: + class AnnotatedClient(bentoml.client.Client): + def health(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + ... + + def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: + ... + + def metadata_v1(self) -> dict[str, t.Any]: + ... + + +class BaseClient(ABC): + _metadata: dict[str, t.Any] + _api_version: str + _config_class: type[bentoml.client.Client] + + _host: str + _port: str + + __config__: openllm.LLMConfig | None = None + __client__: AnnotatedClient | None = None + + def __init__(self, address: str, timeout: int = 30): + self._address = address + self._timeout = timeout + assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation." + self._metadata = self.call("metadata") + + def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): + cls._config_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient + cls._api_version = api_version + + @property + def model(self) -> str: + return self._metadata["model_name"] + + @property + def config(self) -> openllm.LLMConfig: + if self.__config__ is None: + self.__config__ = openllm.AutoConfig.for_model(self.model) + return self.__config__ + + def call(self, name: str, *args: t.Any, **kwargs: t.Any) -> t.Any: + return getattr(self._cached, f"{name}_{self._api_version}")(*args, **kwargs) + + @property + def _cached(self) -> AnnotatedClient: + if self.__client__ is None: + self._config_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) + self.__client__ = t.cast("AnnotatedClient", self._config_class.from_url(self._address)) + return self.__client__ + + def health(self) -> t.Any: raise NotImplementedError @abstractmethod - def chat( - self, - prompt: str, - context: str, - prompt_template: str | openllm.PromptTemplate | None = None, - **llm_config: t.Any, - ) -> str: + def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]: + raise NotImplementedError + + def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py index e1bd54a4..6bd04d5c 100644 --- a/src/openllm_client/runtimes/grpc.py +++ b/src/openllm_client/runtimes/grpc.py @@ -18,8 +18,6 @@ import asyncio import logging import typing as t -import bentoml - import openllm from .base import BaseClient @@ -30,51 +28,26 @@ if t.TYPE_CHECKING: logger = logging.getLogger(__name__) -class GrpcClient(BaseClient): - __: bentoml.client.GrpcClient | None = None - +class GrpcClient(BaseClient, client_type="grpc"): def __init__(self, address: str, timeout: int = 30): - self._timeout = timeout - self._address = address self._host, self._port = address.split(":") - - @property - def _cached(self) -> bentoml.client.GrpcClient: - if self.__ is None: - bentoml.client.GrpcClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) - self.__ = bentoml.client.GrpcClient.from_url(self._address) - return self.__ + super().__init__(address, timeout) def health(self) -> health_pb2.HealthCheckResponse: return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) - def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str: + def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]: return_raw_response = attrs.pop("return_raw_response", False) - model_name = self._cached.model_name() - if prompt_template is None: - # return the default prompt - prompt_template = openllm.PromptTemplate.from_default(model_name) - elif isinstance(prompt_template, str): - prompt_template = openllm.PromptTemplate.from_template(prompt_template) - variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables} - config = openllm.AutoConfig.for_model(model_name).with_options( - **{k: v for k, v in attrs.items() if k not in variables} - ) - r = openllm.schema.GenerationOutput( - **self._cached.generate( - openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict()) + r = openllm.GenerationOutput( + **self.call( + "generate", + openllm.GenerationInput( + prompt=prompt, + llm_config=self.config.with_options(**attrs), + ), ) ) if return_raw_response: - return r.dict() + return r.model_dump() - return prompt_template.to_str(**variables) + "".join(r.responses) - - def chat( - self, - prompt: str, - context: str, - prompt_template: str | openllm.PromptTemplate | None = None, - **llm_config: t.Any, - ): - ... + return r.responses diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index 530da0e5..3d082217 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -18,8 +18,6 @@ import logging import typing as t from urllib.parse import urlparse -import bentoml - import openllm from .base import BaseClient @@ -28,51 +26,26 @@ logger = logging.getLogger(__name__) class HTTPClient(BaseClient): - __: bentoml.client.HTTPClient | None = None - def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address - self._timeout = timeout - self._address = address self._host, self._port = urlparse(address).netloc.split(":") - - @property - def _cached(self) -> bentoml.client.HTTPClient: - if self.__ is None: - bentoml.client.HTTPClient.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) - self.__ = bentoml.client.HTTPClient.from_url(self._address) - return self.__ + super().__init__(address, timeout) def health(self) -> t.Any: return self._cached.health() - def query(self, prompt_template: str | openllm.PromptTemplate | None = None, **attrs: t.Any) -> str: + def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | list[t.Any]: return_raw_response = attrs.pop("return_raw_response", False) - model_name = self._cached.model_name() - if prompt_template is None: - # return the default prompt - prompt_template = openllm.PromptTemplate.from_default(model_name) - elif isinstance(prompt_template, str): - prompt_template = openllm.PromptTemplate.from_template(prompt_template) - variables = {k: v for k, v in attrs.items() if k in prompt_template.input_variables} - config = openllm.AutoConfig.for_model(model_name).with_options( - **{k: v for k, v in attrs.items() if k not in variables} - ) - r = openllm.schema.GenerationOutput( - **self._cached.generate( - openllm.schema.GenerationInput(prompt=prompt_template.to_str(**variables), llm_config=config.dict()) + r = openllm.GenerationOutput( + **self.call( + "generate", + openllm.GenerationInput( + prompt=prompt, + llm_config=self.config.with_options(**attrs), + ), ) ) if return_raw_response: - return r.dict() + return r.model_dump() - return prompt_template.to_str(**variables) + "".join(r.responses) - - def chat( - self, - prompt: str, - context: str, - prompt_template: str | openllm.PromptTemplate | None = None, - **llm_config: t.Any, - ): - ... + return r.responses