mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 07:36:15 -05:00
refactor: packages (#249)
This commit is contained in:
@@ -9,13 +9,18 @@ deploy, and monitor any LLMs with ease.
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging as _logging, os as _os, typing as _t, warnings as _warnings
|
||||
import logging as _logging, os as _os, typing as _t, warnings as _warnings, openllm_core
|
||||
from pathlib import Path as _Path
|
||||
from . import exceptions as exceptions, utils as utils
|
||||
|
||||
if utils.DEBUG:
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
from openllm_core._schema import EmbeddingsOutput as EmbeddingsOutput, GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
|
||||
|
||||
if openllm_core.utils.DEBUG:
|
||||
openllm_core.utils.set_debug_mode(True)
|
||||
openllm_core.utils.set_quiet_mode(False)
|
||||
_logging.basicConfig(level=_logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
@@ -28,40 +33,26 @@ else:
|
||||
_warnings.filterwarnings("ignore", message="Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated")
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"exceptions": [], "models": [], "client": [], "bundle": [], "playground": [], "testing": [], "utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"], "_configuration": ["LLMConfig", "GenerationConfig", "SamplingParams"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
|
||||
"_quantisation": ["infer_quantisation_config"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs", "HfAgentInput"], "_embeddings": ["GenericEmbeddingRunnable"], "_strategies": ["CascadingResourceStrategy", "get_resource"],
|
||||
"models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"], "models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"]
|
||||
"exceptions": [], "models": [], "client": [], "bundle": [], "playground": [], "testing": [],
|
||||
"utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"], "_quantisation": ["infer_quantisation_config"], "_embeddings": ["GenericEmbeddingRunnable"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
|
||||
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": [], "models.baichuan": [], "models.dolly_v2": [], "models.falcon": [], "models.flan_t5": [], "models.gpt_neox": [], "models.llama": [], "models.mpt": [], "models.opt": [], "models.stablelm": [], "models.starcoder": []
|
||||
}
|
||||
COMPILED = _Path(__file__).suffix in (".pyd", ".so")
|
||||
|
||||
if _t.TYPE_CHECKING:
|
||||
from . import bundle as bundle, cli as cli, client as client, models as models, playground as playground, serialisation as serialisation, testing as testing
|
||||
from ._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from ._generation import LogitsProcessorList as LogitsProcessorList, StopOnTokens as StopOnTokens, StoppingCriteriaList as StoppingCriteriaList, StopSequenceCriteria as StopSequenceCriteria, prepare_logits_processor as prepare_logits_processor
|
||||
from ._llm import LLM as LLM, LLMEmbeddings as LLMEmbeddings, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._schema import EmbeddingsOutput as EmbeddingsOutput, GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from ._embeddings import GenericEmbeddingRunnable as GenericEmbeddingRunnable
|
||||
from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
|
||||
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING, MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES, AutoConfig as AutoConfig
|
||||
from .models.baichuan import BaichuanConfig as BaichuanConfig
|
||||
from .models.chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
from .models.dolly_v2 import DollyV2Config as DollyV2Config
|
||||
from .models.falcon import FalconConfig as FalconConfig
|
||||
from .models.flan_t5 import FlanT5Config as FlanT5Config
|
||||
from .models.gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
from .models.llama import LlamaConfig as LlamaConfig
|
||||
from .models.mpt import MPTConfig as MPTConfig
|
||||
from .models.opt import OPTConfig as OPTConfig
|
||||
from .models.stablelm import StableLMConfig as StableLMConfig
|
||||
from .models.starcoder import StarCoderConfig as StarCoderConfig
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"]
|
||||
else:
|
||||
@@ -71,7 +62,7 @@ else:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_triton_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_triton_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["MPT"]
|
||||
@@ -79,7 +70,7 @@ else:
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT
|
||||
try:
|
||||
if not (utils.is_torch_available() and utils.is_einops_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_einops_available()): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"]
|
||||
@@ -88,7 +79,7 @@ else:
|
||||
if _t.TYPE_CHECKING: from .models.falcon import Falcon as Falcon
|
||||
|
||||
try:
|
||||
if not utils.is_torch_available(): raise exceptions.MissingDependencyError
|
||||
if not openllm_core.utils.is_torch_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")]
|
||||
else:
|
||||
@@ -110,7 +101,7 @@ else:
|
||||
from .models.stablelm import StableLM as StableLM
|
||||
from .models.starcoder import StarCoder as StarCoder
|
||||
try:
|
||||
if not utils.is_vllm_available(): raise exceptions.MissingDependencyError
|
||||
if not openllm_core.utils.is_vllm_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
@@ -136,7 +127,7 @@ else:
|
||||
from .models.stablelm import VLLMStableLM as VLLMStableLM
|
||||
from .models.starcoder import VLLMStarCoder as VLLMStarCoder
|
||||
try:
|
||||
if not utils.is_flax_available(): raise exceptions.MissingDependencyError
|
||||
if not openllm_core.utils.is_flax_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
@@ -148,7 +139,7 @@ else:
|
||||
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
from .models.opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not utils.is_tf_available(): raise exceptions.MissingDependencyError
|
||||
if not openllm_core.utils.is_tf_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
@@ -161,7 +152,7 @@ else:
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = utils.LazyModule(__name__, _os.path.abspath("__file__"), _import_structure, extra_objects={"COMPILED": COMPILED})
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED})
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,22 +1,21 @@
|
||||
from __future__ import annotations
|
||||
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid
|
||||
import functools, inspect, logging, os, re, traceback, types, typing as t, uuid, attr, fs.path, inflection, orjson, bentoml, openllm, openllm_core, gc
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
import attr, fs.path, inflection, orjson, bentoml, openllm, gc
|
||||
from huggingface_hub import hf_hub_download
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
|
||||
from ._configuration import (
|
||||
from openllm_core._configuration import (
|
||||
FineTuneConfig,
|
||||
LLMConfig,
|
||||
_object_getattribute,
|
||||
_setattr_class,
|
||||
)
|
||||
from ._quantisation import infer_quantisation_config
|
||||
from ._schema import unmarshal_vllm_outputs
|
||||
from openllm_core._schema import unmarshal_vllm_outputs
|
||||
from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException
|
||||
from .models.auto import AutoConfig
|
||||
from .utils import (
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
MYPY,
|
||||
@@ -29,7 +28,6 @@ from .utils import (
|
||||
device_count,
|
||||
first_not_none,
|
||||
generate_hash_from_file,
|
||||
infer_auto_class,
|
||||
is_peft_available,
|
||||
is_torch_available,
|
||||
non_intrusive_setattr,
|
||||
@@ -37,8 +35,8 @@ from .utils import (
|
||||
resolve_filepath,
|
||||
validate_is_path,
|
||||
)
|
||||
|
||||
from ._typing_compat import (
|
||||
from .utils import infer_auto_class
|
||||
from openllm_core._typing_compat import (
|
||||
AdaptersMapping,
|
||||
AdaptersTuple,
|
||||
AnyCallable,
|
||||
@@ -57,8 +55,8 @@ from ._typing_compat import (
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import auto_gptq as autogptq, peft, torch, transformers, vllm
|
||||
from ._configuration import PeftType
|
||||
from .utils.representation import ReprArgs
|
||||
from openllm_core._configuration import PeftType
|
||||
from openllm_core.utils.representation import ReprArgs
|
||||
else:
|
||||
autogptq = LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
vllm = LazyLoader("vllm", globals(), "vllm")
|
||||
@@ -156,27 +154,6 @@ class LLMInterface(ABC, t.Generic[M, T]):
|
||||
"""The iterator version of `generate` function."""
|
||||
raise NotImplementedError("Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented.")
|
||||
|
||||
def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]:
|
||||
"""This handler will sanitize all attrs and setup prompt text.
|
||||
|
||||
It takes a prompt that is given by the user, attrs that can be parsed with the prompt.
|
||||
|
||||
Returns a tuple of three items:
|
||||
- The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig
|
||||
- The attributes dictionary that will be passed into `self.postprocess_generate`.
|
||||
"""
|
||||
return prompt, attrs, attrs
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any:
|
||||
"""This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.).
|
||||
|
||||
You can customize how the output of the LLM looks with this hook. By default, it is a simple echo.
|
||||
|
||||
> [!NOTE]
|
||||
> This will be used from the client side.
|
||||
"""
|
||||
return generation_result
|
||||
|
||||
def llm_post_init(self) -> None:
|
||||
"""This function can be implemented if you need to initialized any additional variables that doesn't concern OpenLLM internals."""
|
||||
pass
|
||||
@@ -380,9 +357,7 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]
|
||||
anns[key] = interface_anns.get(key)
|
||||
return codegen.generate_function(cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns)
|
||||
|
||||
def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str:
|
||||
return generation_result[0]["outputs"][0]["text"]
|
||||
|
||||
def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: return generation_result[0]["outputs"][0]["text"]
|
||||
def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]:
|
||||
outputs: list[vllm.RequestOutput] = []
|
||||
# TODO: support prompt_token_ids
|
||||
@@ -430,8 +405,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
elif "config_class" not in cd: raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.")
|
||||
_make_assignment_script(cls)(cls)
|
||||
if "tokenizer_id" not in cd and cls.__llm_implementation__ == "vllm": cls.tokenizer_id = _DEFAULT_TOKENIZER
|
||||
|
||||
# fmt: off
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ...
|
||||
@overload
|
||||
@@ -459,24 +432,14 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
if hasattr(self, internal_attributes): return getattr(self, internal_attributes)
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
else: raise KeyError(item)
|
||||
@classmethod
|
||||
@overload
|
||||
def from_pretrained(
|
||||
cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ...,
|
||||
quantization_config: transformers.BitsAndBytesConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any
|
||||
) -> LLM[M, T]: ...
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: transformers.BitsAndBytesConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLM[M, T]: ...
|
||||
@overload
|
||||
def from_pretrained(
|
||||
cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["gptq"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ...,
|
||||
quantization_config: autogptq.BaseQuantizeConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any
|
||||
) -> LLM[M, T]: ...
|
||||
# fmt: on
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_id: str | None = None, model_version: str | None = None, llm_config: LLMConfig | None = None, *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: str | bool | None = None, adapter_id: str | None = None, adapter_name: str | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, serialisation: t.Literal["safetensors", "legacy"] = "safetensors", **attrs: t.Any,
|
||||
) -> LLM[M, T]:
|
||||
def from_pretrained(cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["gptq"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: autogptq.BaseQuantizeConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLM[M, T]: ...
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str | None = None, model_version: str | None = None, llm_config: LLMConfig | None = None, *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: str | bool | None = None, adapter_id: str | None = None, adapter_name: str | None = None, adapter_map: dict[str, str | None] | None = None, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None = None, serialisation: t.Literal["safetensors", "legacy"] = "safetensors", **attrs: t.Any) -> LLM[M, T]:
|
||||
"""Instantiate a pretrained LLM.
|
||||
|
||||
``LLM.from_pretrained`` follows the same design principle as HuggingFace's `from_pretrained` method, plus the following:
|
||||
@@ -708,7 +671,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
# NOTE: Save the args and kwargs for latter load
|
||||
self.__attrs_init__(llm_config, quantization_config, model_id, _runtime, args, {**model_kwds, **normalized_model_kwds}, {**tokenizer_kwds, **normalized_tokenizer_kwds}, _tag, _adapters_mapping, _model_version, _quantize_method, _serialisation_format, _local)
|
||||
|
||||
# handle trust_remote_code
|
||||
_from_env = os.getenv("TRUST_REMOTE_CODE", None)
|
||||
self.__llm_trust_remote_code__ = first_not_none(str(_from_env).upper() in ENV_VARS_TRUE_VALUES if _from_env else None, default=self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]))
|
||||
@@ -723,7 +685,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
def __setattr__(self, attr: str, value: t.Any) -> None:
|
||||
if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.")
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
@property
|
||||
def adapters_mapping(self) -> AdaptersMapping | None: return self._adapters_mapping
|
||||
@adapters_mapping.setter
|
||||
@@ -740,6 +701,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
def runtime(self) -> t.Literal["ggml", "transformers"]: return self._runtime
|
||||
@property
|
||||
def runner_name(self) -> str: return f"llm-{self.config['start_name']}-runner"
|
||||
# NOTE: The section below defines a loose contract with langchain's LLM interface.
|
||||
@property
|
||||
def llm_type(self) -> str: return normalise_model_name(self._model_id)
|
||||
@property
|
||||
@@ -755,6 +717,27 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
if self.__llm_bentomodel__ is None: self.__llm_bentomodel__ = openllm.serialisation.get(self)
|
||||
return self.__llm_bentomodel__
|
||||
|
||||
def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]:
|
||||
"""This handler will sanitize all attrs and setup prompt text.
|
||||
|
||||
It takes a prompt that is given by the user, attrs that can be parsed with the prompt.
|
||||
|
||||
Returns a tuple of three items:
|
||||
- The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig
|
||||
- The attributes dictionary that will be passed into `self.postprocess_generate`.
|
||||
"""
|
||||
return self.config.sanitize_parameters(prompt, **attrs)
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any:
|
||||
"""This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.).
|
||||
|
||||
You can customize how the output of the LLM looks with this hook. By default, it is a simple echo.
|
||||
|
||||
> [!NOTE]
|
||||
> This will be used from the client side.
|
||||
"""
|
||||
if isinstance(generation_result, dict): return generation_result["text"]
|
||||
return self.config.postprocess_generate(prompt, generation_result, **attrs)
|
||||
|
||||
@property
|
||||
def model(self) -> M:
|
||||
# Run check for GPU
|
||||
@@ -868,7 +851,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
|
||||
# order of these fields matter here, make sure to sync it with
|
||||
# openllm.models.auto.factory.BaseAutoLLMClass.for_model
|
||||
def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] | None = None) -> LLMRunner[M, T]:
|
||||
def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] = openllm_core.CascadingResourceStrategy) -> LLMRunner[M, T]:
|
||||
"""Convert this LLM into a Runner.
|
||||
|
||||
Args:
|
||||
@@ -894,10 +877,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
try: models.append(self._bentomodel)
|
||||
except bentoml.exceptions.NotFound as err: raise RuntimeError(f"Failed to locate {self._bentomodel}:{err}") from None
|
||||
|
||||
if scheduling_strategy is None:
|
||||
from ._strategies import CascadingResourceStrategy
|
||||
scheduling_strategy = CascadingResourceStrategy
|
||||
|
||||
generate_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False)))
|
||||
embeddings_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=True, batch_dim=0)))
|
||||
generate_iterator_sig = ModelSignature.from_dict(t.cast("_ModelSignatureDict", ModelSignatureDict(batchable=False)))
|
||||
@@ -932,10 +911,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
for it in self.generate_iterator(prompt, **attrs): pass
|
||||
return [it]
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> str:
|
||||
if isinstance(generation_result, dict): return generation_result["text"]
|
||||
return generation_result
|
||||
|
||||
def generate_iterator(self, prompt: str, /,
|
||||
*, context_length: int | None = None, echo: bool = True, stream_interval: int = 2, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any) -> t.Iterator[t.Any]:
|
||||
# NOTE: encoder-decoder models will need to implement their own generate_iterator for now
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import string, typing as t
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0: raise ValueError("Positional arguments are not supported")
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras: raise KeyError(f"Extra params passed: {extras}")
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template: return prompt
|
||||
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
try: return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
@@ -1,12 +1,11 @@
|
||||
# mypy: disable-error-code="name-defined"
|
||||
# mypy: disable-error-code="name-defined,no-redef"
|
||||
from __future__ import annotations
|
||||
import logging, sys, typing as t
|
||||
from .utils import LazyLoader, is_autogptq_available, is_bitsandbytes_available, is_transformers_supports_kbit, pkg
|
||||
if sys.version_info[:2] >= (3, 11): from typing import overload
|
||||
else: from typing_extensions import overload
|
||||
import logging, typing as t
|
||||
from openllm_core.utils import LazyLoader, is_autogptq_available, is_bitsandbytes_available, is_transformers_supports_kbit, pkg
|
||||
from openllm_core._typing_compat import overload
|
||||
if t.TYPE_CHECKING:
|
||||
from ._llm import LLM
|
||||
from ._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
autogptq, torch, transformers = LazyLoader("autogptq", globals(), "auto_gptq"), LazyLoader("torch", globals(), "torch"), LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Schema definition for OpenLLM. This can be use for client interaction."""
|
||||
from __future__ import annotations
|
||||
import functools, typing as t
|
||||
import attr, inflection, openllm
|
||||
from ._configuration import GenerationConfig, LLMConfig
|
||||
from .utils import bentoml_cattr
|
||||
if t.TYPE_CHECKING: import vllm
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationInput:
|
||||
prompt: str
|
||||
llm_config: LLMConfig
|
||||
adapter_name: str | None = attr.field(default=None)
|
||||
def model_dump(self) -> dict[str, t.Any]: return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True), "adapter_name": self.adapter_name}
|
||||
@staticmethod
|
||||
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
|
||||
if isinstance(data, LLMConfig): return data
|
||||
else:
|
||||
if cls is None: raise ValueError("'cls' must pass if given data is a dictionary.")
|
||||
return cls(**data)
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: return cls.from_llm_config(openllm.AutoConfig.for_model(model_name, **attrs))
|
||||
@classmethod
|
||||
def from_llm_config(cls, llm_config: openllm.LLMConfig) -> type[GenerationInput]: return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)), "adapter_name": attr.field(default=None, type=str)})
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationOutput:
|
||||
responses: t.List[t.Any]
|
||||
configuration: t.Dict[str, t.Any]
|
||||
@property
|
||||
def marshaled_config(self) -> GenerationConfig: return bentoml_cattr.structure(self.configuration, GenerationConfig)
|
||||
@property
|
||||
def unmarshaled(self) -> dict[str, t.Any]: return bentoml_cattr.unstructure(self)
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if hasattr(self, key): return getattr(self, key)
|
||||
elif key in self.configuration: return self.configuration[key]
|
||||
else: raise KeyError(key)
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
model_id: str
|
||||
timeout: int
|
||||
model_name: str
|
||||
framework: str
|
||||
configuration: str
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
@attr.frozen(slots=True)
|
||||
class EmbeddingsOutput:
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]: return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs])
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
@@ -1,334 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, inspect, logging, math, os, sys, types, typing as t, warnings, psutil, bentoml
|
||||
from bentoml._internal.resource import get_resource, system_resources
|
||||
from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
from .utils import DEBUG, ReprMixin
|
||||
if sys.version_info[:2] >= (3, 11): from typing import overload
|
||||
else: from typing_extensions import overload
|
||||
|
||||
class DynResource(t.Protocol):
|
||||
resource_id: t.ClassVar[str]
|
||||
@classmethod
|
||||
def from_system(cls) -> t.Sequence[t.Any]: ...
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in "+-")): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
rcs: list[str] = []
|
||||
for elem in lst.split(","):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
_STACK_LEVEL = 3
|
||||
|
||||
|
||||
@overload # variant: default callback
|
||||
def _parse_visible_devices() -> list[str] | None: ...
|
||||
@overload # variant: specify None, and respect_env
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None: ...
|
||||
@overload # variant: default var is something other than None
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]: ...
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
spec = os.environ.get("CUDA_VISIBLE_DEVICES", default_var)
|
||||
if not spec: return None
|
||||
else:
|
||||
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
|
||||
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
|
||||
# XXX: We need to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
||||
rc: list[int] = []
|
||||
for el in spec.split(","):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append("/opt/rocm/libexec/rocm_smi")
|
||||
try:
|
||||
from ctypes import byref, c_uint32
|
||||
|
||||
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
|
||||
from rsmiBindings import rocmsmi, rsmi_status_t
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return []
|
||||
finally:
|
||||
sys.path.remove("/opt/rocm/libexec/rocm_smi")
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
except (ImportError, RuntimeError, AttributeError):
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]: ...
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError("Spec cannot be < -1.")
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list): return [str(x) for x in spec]
|
||||
else: raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer
|
||||
|
||||
try: nvml_h = CDLL("libnvidia-ml.so.1")
|
||||
except Exception:
|
||||
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
|
||||
rc = nvml_h.nvmlInit()
|
||||
if rc != 0:
|
||||
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
if rc != 0:
|
||||
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids: list[str] = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids.append(buf.raw.decode("ascii").strip("\0"))
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("Failed to initialise CUDA runtime binding.")
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith("GPU-") or el.startswith("MIG-"):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
|
||||
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {"resource_id": resource_kind}, lambda ns: ns.update({"resource_id": resource_kind, "from_spec": classmethod(_from_spec), "from_system": classmethod(_from_system), "validate": classmethod(_validate), "__repr_keys__": property(lambda _: {"resource_id"}), "__doc__": inspect.cleandoc(docstring), "__module__": "openllm._strategies"}),
|
||||
)
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal :facepalm:
|
||||
_TPU_RESOURCE: t.Literal["cloud-tpus.google.com/v2"] = "cloud-tpus.google.com/v2"
|
||||
_AMD_GPU_RESOURCE: t.Literal["amd.com/gpu"] = "amd.com/gpu"
|
||||
_NVIDIA_GPU_RESOURCE: t.Literal["nvidia.com/gpu"] = "nvidia.com/gpu"
|
||||
_CPU_RESOURCE: t.Literal["cpu"] = "cpu"
|
||||
|
||||
NvidiaGpuResource = _make_resource_class("NvidiaGpuResource", _NVIDIA_GPU_RESOURCE, """NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""")
|
||||
AmdGpuResource = _make_resource_class("AmdGpuResource", _AMD_GPU_RESOURCE, """AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""")
|
||||
|
||||
LiteralResourceSpec = t.Literal["cloud-tpus.google.com/v2", "amd.com/gpu", "nvidia.com/gpu", "cpu"]
|
||||
|
||||
# convenient mapping
|
||||
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
|
||||
if name == "tpu": return _TPU_RESOURCE
|
||||
elif name == "amd": return _AMD_GPU_RESOURCE
|
||||
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
|
||||
elif name == "cpu": return _CPU_RESOURCE
|
||||
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
|
||||
|
||||
@functools.lru_cache
|
||||
def available_resource_spec() -> tuple[LiteralResourceSpec, ...]:
|
||||
"""This is a utility function helps to determine the available resources from given running system.
|
||||
|
||||
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
|
||||
|
||||
TODO: Supports TPUs
|
||||
"""
|
||||
available: list[LiteralResourceSpec] = []
|
||||
if len(AmdGpuResource.from_system()) > 0: available.append(_AMD_GPU_RESOURCE)
|
||||
if len(NvidiaGpuResource.from_system()) > 0: available.append(_NVIDIA_GPU_RESOURCE)
|
||||
available.append(_CPU_RESOURCE)
|
||||
return tuple(available)
|
||||
|
||||
class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
|
||||
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
|
||||
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
|
||||
for ROCm's support for CUDA_VISIBLE_DEVICES.
|
||||
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
"""
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
nvidia_req = get_resource(resource_request, kind)
|
||||
if nvidia_req is not None: return 1
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
amd_req = get_resource(resource_request, kind, validate=False)
|
||||
if amd_req is not None: return 1
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
|
||||
return int(workers_per_resource)
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
disabled = cuda_env in ("", "-1")
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
typ = get_resource(resource_request, kind)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
typ = get_resource(resource_request, kind, validate=False)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
if cpus is not None and cpus > 0:
|
||||
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
thread_count = math.ceil(cpus)
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, str(thread_count))
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
return environ
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, "1")
|
||||
return environ
|
||||
return environ
|
||||
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource: float | int, gpus: list[str], worker_index: int) -> str:
|
||||
# Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.
|
||||
if isinstance(workers_per_resource, float):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)]
|
||||
dev = ",".join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus): raise ValueError(f"Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}")
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
__all__=["CascadingResourceStrategy", "get_resource"]
|
||||
@@ -1,102 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t, bentoml, attr, abc
|
||||
from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm, peft, transformers, auto_gptq as autogptq, vllm
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
from .bundle.oci import LiteralContainerVersionStrategy
|
||||
from .utils.lazy import VersionInfo
|
||||
|
||||
M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]")
|
||||
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = t.Dict[str, t.Any]
|
||||
ListAny = t.List[t.Any]
|
||||
ListStr = t.List[str]
|
||||
TupleAny = t.Tuple[t.Any, ...]
|
||||
At = t.TypeVar("At", bound=attr.AttrsInstance)
|
||||
|
||||
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
|
||||
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
|
||||
|
||||
if sys.version_info[:2] >= (3,11):
|
||||
from typing import LiteralString as LiteralString, Self as Self, overload as overload
|
||||
from typing import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
|
||||
else:
|
||||
from typing_extensions import LiteralString as LiteralString, Self as Self, overload as overload
|
||||
from typing_extensions import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform
|
||||
|
||||
if sys.version_info[:2] >= (3,10):
|
||||
from typing import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
|
||||
else:
|
||||
from typing_extensions import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate
|
||||
|
||||
if sys.version_info[:2] >= (3,9):
|
||||
from typing import TypedDict as TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict as TypedDict
|
||||
|
||||
class PeftAdapterOutput(TypedDict):
|
||||
success: bool
|
||||
result: t.Dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
class LLMEmbeddings(t.TypedDict):
|
||||
embeddings: t.List[t.List[float]]
|
||||
num_tokens: int
|
||||
|
||||
class AdaptersTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: t.Optional[str]
|
||||
config: DictStrAny
|
||||
|
||||
AdaptersMapping = t.Dict[AdapterType, t.Tuple[AdaptersTuple, ...]]
|
||||
|
||||
class RefTuple(TupleAny):
|
||||
git_hash: str
|
||||
version: VersionInfo
|
||||
strategy: LiteralContainerVersionStrategy
|
||||
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
|
||||
|
||||
class LLMRunner(bentoml.Runner, t.Generic[M, T]):
|
||||
__doc__: str
|
||||
__module__: str
|
||||
llm_type: str
|
||||
identifying_params: dict[str, t.Any]
|
||||
llm: openllm.LLM[M, T]
|
||||
config: openllm.LLMConfig
|
||||
implementation: LiteralRuntime
|
||||
supports_embeddings: bool
|
||||
supports_hf_agent: bool
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[LLMEmbeddings]]
|
||||
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
|
||||
def __init__(self, runnable_class: type[LLMRunnable[M, T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ...
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
@abc.abstractmethod
|
||||
def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ...
|
||||
def run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ...
|
||||
@abc.abstractmethod
|
||||
def download_model(self) -> bentoml.Model: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def peft_adapters(self) -> PeftAdapterOutput: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
@@ -4,15 +4,12 @@ These utilities will stay internal, and its API can be changed or updated withou
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os, typing as t
|
||||
from openllm.utils import LazyModule
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from . import (
|
||||
_package as _package,
|
||||
oci as oci,
|
||||
)
|
||||
from . import _package as _package, oci as oci
|
||||
from ._package import (
|
||||
build_editable as build_editable,
|
||||
construct_docker_options as construct_docker_options,
|
||||
@@ -28,7 +25,7 @@ if t.TYPE_CHECKING:
|
||||
supported_registries as supported_registries,
|
||||
)
|
||||
|
||||
__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
__lazy = LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -1,35 +1,34 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
from __future__ import annotations
|
||||
import importlib.metadata, inspect, logging, os, typing as t
|
||||
import fs, fs.copy, fs.errors, orjson, bentoml, openllm_core, importlib.metadata, inspect, logging, os, typing as t, string
|
||||
from pathlib import Path
|
||||
import fs, fs.copy, fs.errors, orjson, bentoml, openllm
|
||||
from simple_di import Provide, inject
|
||||
from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from . import oci
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from fs.base import FS
|
||||
from openllm._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import LiteralString, LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
from .oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = "OPENLLM_DEV_BUILD"
|
||||
|
||||
def build_editable(path: str) -> str | None:
|
||||
def build_editable(path: str, package: t.Literal["openllm", "openllm_core", "openllm_client"] = "openllm") -> str | None:
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = openllm.utils.pkg.source_locations("openllm")
|
||||
module_location = openllm_core.utils.pkg.source_locations(package)
|
||||
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
|
||||
pyproject_path = Path(module_location).parent.parent/"pyproject.toml"
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info("OpenLLM is installed in editable mode. Generating built wheels...")
|
||||
logger.info("Generating built wheels for package %s...", package)
|
||||
with IsolatedEnvBuilder() as env:
|
||||
builder = ProjectBuilder(pyproject_path.parent)
|
||||
builder.python_executable = env.executable
|
||||
@@ -49,15 +48,15 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
|
||||
req = llm.config["requirements"]
|
||||
if req is not None: packages.extend(req)
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in openllm.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env = llm.config["env"]
|
||||
framework_envvar = env["framework_value"]
|
||||
if framework_envvar == "flax":
|
||||
if not openllm.utils.is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
if not openllm_core.utils.is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
|
||||
elif framework_envvar == "tf":
|
||||
if not openllm.utils.is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
if not openllm_core.utils.is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos",)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
@@ -68,19 +67,19 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f"tensorflow>={_tf_version}"])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
except importlib.metadata.PackageNotFoundError: pass # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
else:
|
||||
if not openllm.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
if not openllm_core.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
wheels: list[str] = []
|
||||
built_wheels = build_editable(llm_fs.getsyspath("/"))
|
||||
if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}"))
|
||||
built_wheels: list[str | None] = [build_editable(llm_fs.getsyspath("/"), t.cast(t.Literal["openllm", "openllm_core", "openllm_client"], p)) for p in ("openllm_core", "openllm_client", "openllm")]
|
||||
if all(i for i in built_wheels): wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
|
||||
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: float, quantize: LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
environ = parse_config_options(llm.config, llm.config["timeout"], workers_per_resource, None, True, os.environ.copy())
|
||||
env: openllm.utils.EnvVarMixin = llm.config["env"]
|
||||
env: openllm_core.utils.EnvVarMixin = llm.config["env"]
|
||||
if env["framework_value"] == "vllm": serialisation_format = "legacy"
|
||||
env_dict = {
|
||||
env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'",
|
||||
@@ -91,13 +90,45 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_
|
||||
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = openllm.utils.EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
|
||||
env_dict[_env.bettertransformer] = str(_env["bettertransformer_value"])
|
||||
if _env["quantize_value"] is not None: env_dict[_env.quantize] = t.cast(str, _env["quantize_value"])
|
||||
env_dict[_env.runtime] = _env["runtime_value"]
|
||||
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
|
||||
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = "__model_name__"
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any: return super().vformat(format_string, (), {self.model_keyword: self.model_name})
|
||||
def can_format(self, value: str) -> bool:
|
||||
try:
|
||||
self.parse(value)
|
||||
return True
|
||||
except ValueError: return False
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_id__"
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_adapter_map__"
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent/"_service.py"
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
model_name = llm.config["model_name"]
|
||||
logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
|
||||
with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it: src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
if DEBUG: logger.info("Generated script:\n%s", script)
|
||||
llm_fs.writetext(llm.config["service_name"], script)
|
||||
|
||||
@inject
|
||||
def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | float, quantize: LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal[ "ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release",
|
||||
@@ -108,14 +139,14 @@ def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.A
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == "round_robin": workers_per_resource = 1.0
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm.utils.device_count() == 0 else float(1 / openllm.utils.device_count())
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm_core.utils.device_count() == 0 else float(1 / openllm_core.utils.device_count())
|
||||
else:
|
||||
try: workers_per_resource = float(workers_per_resource)
|
||||
except ValueError: raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
# add service.py definition to this temporary folder
|
||||
openllm.utils.codegen.write_service(llm, adapter_map, llm_fs)
|
||||
write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
|
||||
build_config = BentoBuildConfig(
|
||||
@@ -134,7 +165,7 @@ def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.A
|
||||
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
|
||||
script = "".join(service_contents)
|
||||
if openllm.utils.DEBUG: logger.info("Generated script:\n%s", script)
|
||||
if openllm_core.utils.DEBUG: logger.info("Generated script:\n%s", script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
|
||||
from __future__ import annotations
|
||||
import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t
|
||||
import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t, openllm_core
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import attr, orjson, bentoml, openllm
|
||||
from openllm.utils.lazy import VersionInfo
|
||||
from openllm_core.utils.lazy import VersionInfo
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
from ghapi import all
|
||||
from openllm._typing_compat import RefTuple, LiteralString
|
||||
from openllm_core._typing_compat import RefTuple, LiteralString
|
||||
|
||||
all = openllm.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811
|
||||
all = openllm_core.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BUILDER = bentoml.container.get_backend("buildx")
|
||||
ROOT_DIR = pathlib.Path(os.path.abspath("__file__")).parent.parent.parent
|
||||
|
||||
# TODO: support quay
|
||||
LiteralContainerRegistry = t.Literal["docker", "gh", "ecr"]
|
||||
LiteralContainerVersionStrategy = t.Literal["release", "nightly", "latest", "custom"]
|
||||
|
||||
# XXX: This registry will be hard code for now for easier to maintain
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
@@ -31,10 +28,10 @@ _CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io
|
||||
_OWNER = "bentoml"
|
||||
_REPO = "openllm"
|
||||
|
||||
_module_location = openllm.utils.pkg.source_locations("openllm")
|
||||
_module_location = openllm_core.utils.pkg.source_locations("openllm")
|
||||
|
||||
@functools.lru_cache
|
||||
@openllm.utils.apply(str.lower)
|
||||
@openllm_core.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
def _convert_version_from_string(s: str) -> VersionInfo: return VersionInfo.from_version_string(s)
|
||||
@@ -43,7 +40,7 @@ def _commit_time_range(r: int = 5) -> str: return (datetime.now(timezone.utc) -
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm.utils.codegen.make_attr_tuple_class("_RefTuple", ["git_hash", "version", "strategy"])
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class("_RefTuple", ["git_hash", "version", "strategy"])
|
||||
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
@@ -60,7 +57,7 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
git_hash: str = attr.field()
|
||||
version: openllm.utils.VersionInfo = attr.field(converter=_convert_version_from_string)
|
||||
version: openllm_core.utils.VersionInfo = attr.field(converter=_convert_version_from_string)
|
||||
strategy: LiteralContainerVersionStrategy = attr.field()
|
||||
_ghapi: t.ClassVar[all.GhApi] = all.GhApi(owner=_OWNER, repo=_REPO)
|
||||
@classmethod
|
||||
@@ -74,7 +71,7 @@ class RefResolver:
|
||||
version_str = meta["name"].lstrip("v")
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")["object"]["sha"], version_str)
|
||||
else: version = ("", version_str)
|
||||
if openllm.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
if openllm_core.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@@ -101,7 +98,7 @@ def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralCon
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
if openllm.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if openllm_core.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
|
||||
@@ -111,7 +108,7 @@ def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralCon
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if openllm.utils.get_debug_mode() else "auto", quiet=machine)
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if openllm_core.utils.get_debug_mode() else "auto", quiet=machine)
|
||||
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
|
||||
except Exception as err: raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
return tags
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import functools, importlib.util, os, typing as t, logging
|
||||
import click, click_option_group as cog, inflection, orjson, bentoml, openllm
|
||||
import functools, importlib.util, os, typing as t, logging, click, click_option_group as cog, inflection, orjson, bentoml, openllm
|
||||
from click import shell_completion as sc
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click.shell_completion import CompletionItem
|
||||
from openllm.utils import DEBUG
|
||||
from openllm_core.utils import DEBUG
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm._typing_compat import LiteralString, DictStrAny, ParamSpec, Concatenate
|
||||
from openllm_core._typing_compat import LiteralString, DictStrAny, ParamSpec, Concatenate
|
||||
from . import termui
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import subprocess
|
||||
from openllm._configuration import LLMConfig
|
||||
from openllm_core._configuration import LLMConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,6 +20,12 @@ LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help="Bento") for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {"start_name", "bundler"})]
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help="Model") for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
@@ -316,7 +322,7 @@ def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f)
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--quantise", "--quantize", "quantize", type=click.Choice(["int8", "int4", "gptq"]), default=None, envvar=model_env.quantize if model_env is not None else None, show_envvar=model_env is not None, help="""Dynamic quantization for running this LLM.
|
||||
@@ -382,7 +388,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
)(f)
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--container-registry", "container_registry", type=str, default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", callback=container_registry_callback, help="""The default container registry to get the base image for building BentoLLM.
|
||||
"--container-registry", "container_registry", type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)), default="ecr", show_default=True, show_envvar=True, envvar="OPENLLM_CONTAINER_REGISTRY", callback=container_registry_callback, help="""The default container registry to get the base image for building BentoLLM.
|
||||
|
||||
Currently, it supports 'ecr', 'ghcr.io', 'docker.io'
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t
|
||||
import bentoml, openllm
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t, bentoml, openllm, openllm_core
|
||||
from simple_di import Provide, inject
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm.exceptions import OpenLLMException
|
||||
@@ -8,10 +7,9 @@ from . import termui
|
||||
from ._factory import start_command_factory
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import LiteralString, LiteralRuntime
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralString, LiteralRuntime, LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm._configuration import LLMConfig
|
||||
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,7 +56,7 @@ def _start(model_name: str, /, *, model_id: str | None = None, timeout: int = 30
|
||||
"""
|
||||
from .entrypoint import start_command, start_grpc_command
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
_ModelEnv = openllm.utils.EnvVarMixin(model_name, openllm.utils.first_not_none(framework, default=llm_config.default_implementation()), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
_ModelEnv = openllm_core.utils.EnvVarMixin(model_name, openllm_core.utils.first_not_none(framework, default=llm_config.default_implementation()), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
os.environ[_ModelEnv.framework] = _ModelEnv["framework_value"]
|
||||
|
||||
args: list[str] = ["--runtime", runtime]
|
||||
@@ -203,5 +201,5 @@ def _list_models() -> dict[str, t.Any]:
|
||||
return models_command.main(args=["-o", "json", "--show-available", "--machine"], standalone_mode=False)
|
||||
|
||||
|
||||
start, start_grpc, build, import_model, list_models = openllm.utils.codegen.gen_sdk(_start, _serve_grpc=False), openllm.utils.codegen.gen_sdk(_start, _serve_grpc=True), openllm.utils.codegen.gen_sdk(_build), openllm.utils.codegen.gen_sdk(_import_model), openllm.utils.codegen.gen_sdk(_list_models)
|
||||
start, start_grpc, build, import_model, list_models = openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=False), openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=True), openllm_core.utils.codegen.gen_sdk(_build), openllm_core.utils.codegen.gen_sdk(_import_model), openllm_core.utils.codegen.gen_sdk(_list_models)
|
||||
__all__ = ["start", "start_grpc", "build", "import_model", "list_models"]
|
||||
|
||||
@@ -20,10 +20,9 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t
|
||||
import attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
|
||||
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
|
||||
import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t, attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
|
||||
from simple_di import Provide, inject
|
||||
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
from . import termui
|
||||
@@ -56,8 +55,8 @@ from openllm.models.auto import (
|
||||
AutoConfig,
|
||||
AutoLLM,
|
||||
)
|
||||
from openllm._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
|
||||
from openllm.utils import (
|
||||
from openllm_core._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
DEBUG_ENV_VAR,
|
||||
OPTIONAL_DEPENDENCIES,
|
||||
@@ -72,21 +71,20 @@ from openllm.utils import (
|
||||
first_not_none,
|
||||
get_debug_mode,
|
||||
get_quiet_mode,
|
||||
infer_auto_class,
|
||||
is_torch_available,
|
||||
is_transformers_supports_agent,
|
||||
resolve_user_filepath,
|
||||
set_debug_mode,
|
||||
set_quiet_mode,
|
||||
)
|
||||
from openllm.utils import infer_auto_class
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.container import DefaultBuilder
|
||||
from openllm.client import BaseClient
|
||||
from openllm._schema import EmbeddingsOutput
|
||||
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
from openllm_core._schema import EmbeddingsOutput
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
else: torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
P = ParamSpec("P")
|
||||
@@ -271,7 +269,7 @@ def cli() -> None:
|
||||
\b
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
""" # noqa: D205
|
||||
"""
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start", aliases=["start-http"])
|
||||
def start_command() -> None:
|
||||
@@ -670,10 +668,8 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
|
||||
"""
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
|
||||
|
||||
try:
|
||||
client.call("metadata")
|
||||
except http.client.BadStatusLine:
|
||||
raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None
|
||||
try: client.call("metadata")
|
||||
except http.client.BadStatusLine: raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None
|
||||
if agent == "hf":
|
||||
if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'")
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
@@ -700,7 +696,7 @@ def embed_command(ctx: click.Context, text: tuple[str, ...], endpoint: str, time
|
||||
$ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
|
||||
```
|
||||
"""
|
||||
client = t.cast("BaseClient[t.Any]", openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout))
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
try:
|
||||
gen_embed = client.embed(text)
|
||||
except ValueError:
|
||||
@@ -733,14 +729,14 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
|
||||
"""
|
||||
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
|
||||
if server_type == "grpc": endpoint = re.sub(r"http://", "", endpoint)
|
||||
client = t.cast("BaseClient[t.Any]", openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout))
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
input_fg, generated_fg = "magenta", "cyan"
|
||||
if output != "porcelain":
|
||||
termui.echo("==Input==\n", fg="white")
|
||||
termui.echo(f"{prompt}", fg=input_fg)
|
||||
res = client.query(prompt, return_response="raw", **{**client.configuration, **_memoized})
|
||||
if output == "pretty":
|
||||
response = client.llm.postprocess_generate(prompt, res["responses"])
|
||||
response = client.config.postprocess_generate(prompt, res["responses"])
|
||||
termui.echo("\n\n==Responses==\n", fg="white")
|
||||
termui.echo(response, fg=generated_fg)
|
||||
elif output == "json":
|
||||
|
||||
@@ -1,37 +1,26 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
from .. import termui
|
||||
from .._factory import machine_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
|
||||
import typing as t, click, orjson, openllm
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import machine_option, container_registry_option
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
@click.command(
|
||||
"build_base_container", context_settings=termui.CONTEXT_SETTINGS, help="""Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
|
||||
\b
|
||||
If '--machine' is passed, then it will run the process quietly, and output a JSON to the current running terminal.
|
||||
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
"""
|
||||
)
|
||||
@click.option("--registry", multiple=True, type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)), help="Target registry to create image tag on.", default=None)
|
||||
@container_registry_option
|
||||
@click.option("--version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="nightly", help="Version strategy to use for tagging the image.")
|
||||
@click.option("--push/--no-push", help="Whether to push to remote repository", is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(registry, version_strategy, push, machine)
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(container_registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
return mapping
|
||||
|
||||
@@ -1,24 +1,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import psutil
|
||||
import shutil, subprocess, typing as t, click, psutil, bentoml
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
from .. import termui
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar, machine_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command("dive_bentos", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("bento", type=str)
|
||||
@click.option("--machine", is_flag=True, default=False, hidden=True)
|
||||
@click.argument("bento", type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
@@ -32,5 +24,5 @@ def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore
|
||||
if machine: return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
|
||||
else: subprocess.check_call([shutil.which("tree") or "tree"], cwd=bentomodel.path)
|
||||
else: subprocess.check_call([shutil.which("ls") or "ls", "-Rrthla"], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import typing as t, click, bentoml
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.bento.bento import BentoInfo
|
||||
from bentoml._internal.bento.build_config import DockerOptions
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.container.generate import generate_containerfile
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
|
||||
from .. import termui
|
||||
from ...utils import bentoml_cattr
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command("get_containerfile", context_settings=termui.CONTEXT_SETTINGS, help="Return Containerfile of any given Bento.")
|
||||
@click.argument("bento", type=str)
|
||||
@click.argument("bento", type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
|
||||
|
||||
@@ -1,25 +1,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
import typing as t, click, inflection, orjson, openllm
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
|
||||
from .. import termui
|
||||
from ..._prompt import process_prompt
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar, output_option, machine_option
|
||||
from openllm_core._prompt import process_prompt
|
||||
|
||||
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
|
||||
@click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument("prompt", type=click.STRING)
|
||||
@click.option("-o", "--output", "output", type=click.Choice(["json", "pretty", "porcelain"]), default="pretty", help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True)
|
||||
@output_option
|
||||
@click.option("--format", type=click.STRING, default=None)
|
||||
@click.option("--machine", is_flag=True, default=False, hidden=True)
|
||||
@machine_option
|
||||
@click.option("--opt", help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]")
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import click, inflection, orjson, bentoml, openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
|
||||
from .. import termui
|
||||
from .._factory import LiteralOutput, output_option
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, output_option
|
||||
|
||||
@click.command("list_bentos", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@output_option(default_value="json")
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, bentoml, openllm, orjson, inflection ,click
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, model_name_argument, output_option
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli._factory import LiteralOutput, model_name_argument, output_option, model_complete_envvar
|
||||
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
@output_option(default_value="json")
|
||||
def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import importlib.machinery, logging, os, pkgutil, subprocess, sys, tempfile, typing as t
|
||||
import click, yaml
|
||||
import importlib.machinery, logging, os, pkgutil, subprocess, sys, tempfile, typing as t, click, yaml
|
||||
from openllm.cli import termui
|
||||
from openllm import playground
|
||||
from openllm.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
from openllm_core.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import jupytext, nbformat
|
||||
from openllm._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +37,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
\b
|
||||
> [!NOTE]
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
""" # noqa: D301
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
|
||||
metadata = load_notebook_metadata()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import os, typing as t, click, inflection, openllm
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
|
||||
attrs["fg"] = fg if not openllm.utils.get_debug_mode() else None
|
||||
|
||||
17
openllm-python/src/openllm/client.py
Normal file
17
openllm-python/src/openllm/client.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
|
||||
If the server has embedding supports, use it via `client.embed`:
|
||||
```python
|
||||
client.embed("What is the difference between gather and scatter?")
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import openllm_client, typing as t
|
||||
if t.TYPE_CHECKING: from openllm_client import AsyncHTTPClient as AsyncHTTPClient, BaseAsyncClient as BaseAsyncClient, BaseClient as BaseClient, HTTPClient as HTTPClient, GrpcClient as GrpcClient, AsyncGrpcClient as AsyncGrpcClient
|
||||
def __dir__() -> t.Sequence[str]: return sorted(dir(openllm_client))
|
||||
def __getattr__(it: str) -> t.Any: return getattr(openllm_client, it)
|
||||
@@ -1,22 +0,0 @@
|
||||
"""OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
|
||||
If the server has embedding supports, use it via `client.embed`:
|
||||
```python
|
||||
client.embed("What is the difference between gather and scatter?")
|
||||
```
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm.client.runtimes import (
|
||||
AsyncGrpcClient as AsyncGrpcClient,
|
||||
AsyncHTTPClient as AsyncHTTPClient,
|
||||
BaseAsyncClient as BaseAsyncClient,
|
||||
BaseClient as BaseClient,
|
||||
GrpcClient as GrpcClient,
|
||||
HTTPClient as HTTPClient,
|
||||
)
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Client that supports REST/gRPC protocol to interact with a LLMServer."""
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm.client.runtimes.base import (
|
||||
BaseAsyncClient as BaseAsyncClient,
|
||||
BaseClient as BaseClient,
|
||||
)
|
||||
from openllm.client.runtimes.grpc import (
|
||||
AsyncGrpcClient as AsyncGrpcClient,
|
||||
GrpcClient as GrpcClient,
|
||||
)
|
||||
from openllm.client.runtimes.http import (
|
||||
AsyncHTTPClient as AsyncHTTPClient,
|
||||
HTTPClient as HTTPClient,
|
||||
)
|
||||
@@ -1,238 +0,0 @@
|
||||
# mypy: disable-error-code="name-defined"
|
||||
from __future__ import annotations
|
||||
import asyncio, logging, typing as t
|
||||
import bentoml, bentoml.client, openllm, httpx
|
||||
from abc import abstractmethod
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urljoin
|
||||
from openllm._typing_compat import overload, LiteralString
|
||||
|
||||
T = t.TypeVar("T")
|
||||
T_co = t.TypeVar("T_co", covariant=True)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
from openllm._typing_compat import DictStrAny, LiteralRuntime
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
class AnnotatedClient(t.Protocol[T_co]):
|
||||
server_url: str
|
||||
_svc: bentoml.Service
|
||||
endpoints: list[str]
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ...
|
||||
async def async_health(self) -> t.Any: ...
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> T_co: ...
|
||||
def metadata_v1(self) -> T_co: ...
|
||||
def embeddings_v1(self) -> t.Sequence[float]: ...
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
|
||||
async def async_call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ...
|
||||
@staticmethod
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None: ...
|
||||
@staticmethod
|
||||
def from_url(server_url: str) -> AnnotatedClient[t.Any]: ...
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def in_async_context() -> bool:
|
||||
try:
|
||||
_ = asyncio.get_running_loop()
|
||||
return True
|
||||
except RuntimeError: return False
|
||||
|
||||
class ClientMeta(t.Generic[T]):
|
||||
_api_version: str
|
||||
_client_type: t.Literal["GrpcClient", "HTTPClient"]
|
||||
_host: str
|
||||
_port: str
|
||||
|
||||
__client__: AnnotatedClient[T] | None = None
|
||||
__agent__: transformers.HfAgent | None = None
|
||||
__llm__: openllm.LLM[t.Any, t.Any] | None = None
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30): self._address,self._timeout = address,timeout
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): cls._client_type, cls._api_version = "HTTPClient" if client_type == "http" else "GrpcClient", api_version
|
||||
@property
|
||||
def _hf_agent(self) -> transformers.HfAgent:
|
||||
if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.")
|
||||
if self.__agent__ is None:
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
|
||||
self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent"))
|
||||
return self.__agent__
|
||||
@property
|
||||
def _metadata(self) -> T: return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() if in_async_context() else self.call("metadata")
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def framework(self) -> LiteralRuntime: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def timeout(self) -> int: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_id(self) -> str: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def configuration(self) -> dict[str, t.Any]: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_embeddings(self) -> bool: raise NotImplementedError
|
||||
@property
|
||||
@abstractmethod
|
||||
def supports_hf_agent(self) -> bool: raise NotImplementedError
|
||||
@abstractmethod
|
||||
def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ...
|
||||
@abstractmethod
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
|
||||
|
||||
@property
|
||||
def config(self) -> openllm.LLMConfig: return self.llm.config
|
||||
@property
|
||||
def llm(self) -> openllm.LLM[t.Any, t.Any]:
|
||||
# XXX: if the server runs vllm or any framework that is not available from the user client, client will fail.
|
||||
if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
|
||||
return self.__llm__
|
||||
|
||||
def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs)
|
||||
@property
|
||||
def _cached(self) -> AnnotatedClient[T]:
|
||||
client_class = t.cast(AnnotatedClient[T], getattr(bentoml.client, self._client_type))
|
||||
if self.__client__ is None:
|
||||
client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout)
|
||||
self.__client__ = client_class.from_url(self._address)
|
||||
return self.__client__
|
||||
|
||||
class BaseClient(ClientMeta[T]):
|
||||
def health(self) -> t.Any: raise NotImplementedError
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
if return_raw_response is True: return_response = "raw"
|
||||
return_attrs = attrs.pop("return_attrs", None)
|
||||
if return_attrs is not None:
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = "attrs"
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
|
||||
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
|
||||
if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/generate"), json=inputs.model_dump(), timeout=self.timeout).json()
|
||||
else: result = self.call("generate", inputs.model_dump())
|
||||
r = self.postprocess(result)
|
||||
if return_response == "attrs": return r
|
||||
elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs))
|
||||
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
try: return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
|
||||
except Exception as err:
|
||||
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
|
||||
logger.info("Tip: LLMServer at '%s' might not support 'generate_one'.", self._address)
|
||||
|
||||
class BaseAsyncClient(ClientMeta[T]):
|
||||
async def health(self) -> t.Any: raise NotImplementedError
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
if return_raw_response is True: return_response = "raw"
|
||||
return_attrs = attrs.pop("return_attrs", None)
|
||||
if return_attrs is not None:
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = "attrs"
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.llm.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
|
||||
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
|
||||
res = await self.acall("generate", inputs.model_dump())
|
||||
r = self.postprocess(res)
|
||||
|
||||
if return_response == "attrs": return r
|
||||
elif return_response == "raw": return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
else: return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ...
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ...
|
||||
@overload
|
||||
async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ...
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs))
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
"""Async version of agent.run."""
|
||||
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
|
||||
from transformers.tools.agents import clean_code_for_run, get_tool_creation_code, resolve_tools
|
||||
from transformers.tools.python_interpreter import evaluate
|
||||
|
||||
_hf_agent = self._hf_agent
|
||||
|
||||
prompt = t.cast(str, _hf_agent.format_prompt(task))
|
||||
stop = ["Task:"]
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
|
||||
response = await client.post(_hf_agent.url_endpoint, json={"inputs": prompt, "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},},)
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
raise ValueError(f"Error {response.status_code}: {response.json()}")
|
||||
|
||||
result = response.json()[0]["generated_text"]
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq):
|
||||
result = result[:-len(stop_seq)]
|
||||
break
|
||||
|
||||
# the below have the same logic as agent.run API
|
||||
explanation, code = clean_code_for_run(result)
|
||||
_hf_agent.log(f"==Explanation from the agent==\n{explanation}")
|
||||
_hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
_hf_agent.log("\n\n==Result==")
|
||||
_hf_agent.cached_tools = resolve_tools(code, _hf_agent.toolbox, remote=remote, cached_tools=_hf_agent.cached_tools)
|
||||
return evaluate(code, _hf_agent.cached_tools, state=kwargs.copy())
|
||||
else:
|
||||
tool_code = get_tool_creation_code(code, _hf_agent.toolbox, remote=remote)
|
||||
return f"{tool_code}\n{code}"
|
||||
@@ -1,93 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import asyncio, logging, typing as t
|
||||
import orjson, openllm
|
||||
from openllm._typing_compat import LiteralRuntime
|
||||
from .base import BaseAsyncClient, BaseClient
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from grpc_health.v1 import health_pb2
|
||||
from bentoml.grpc.v1.service_pb2 import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GrpcClient(BaseClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
def health(self) -> health_pb2.HealthCheckResponse: return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
|
||||
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
|
||||
return value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
if isinstance(result, dict): return openllm.GenerationOutput(**result)
|
||||
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
|
||||
|
||||
class AsyncGrpcClient(BaseAsyncClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
async def health(self) -> health_pb2.HealthCheckResponse: return await self._cached.health("bentoml.grpc.v1.BentoService")
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value)
|
||||
if value not in ("pt", "flax", "tf", "vllm"): raise KeyError
|
||||
return value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
if isinstance(result, dict): return openllm.GenerationOutput(**result)
|
||||
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
|
||||
@@ -1,96 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import httpx, orjson, openllm
|
||||
from .base import BaseAsyncClient, BaseClient, in_async_context
|
||||
from openllm._typing_compat import DictStrAny, LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def process_address(self: AsyncHTTPClient | HTTPClient, address: str) -> None:
|
||||
address = address if "://" in address else "http://" + address
|
||||
parsed = urlparse(address)
|
||||
self._host, *_port = parsed.netloc.split(":")
|
||||
if len(_port) == 0: self._port = "80" if parsed.scheme == "http" else "443"
|
||||
else: self._port = next(iter(_port))
|
||||
|
||||
class HTTPClient(BaseClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
process_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
def health(self) -> t.Any: return self._cached.health()
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json() if in_async_context() else self.call("embeddings", list(prompt))
|
||||
return openllm.EmbeddingsOutput(**result)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try: return self._metadata["framework"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return self._metadata["timeout"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)
|
||||
|
||||
class AsyncHTTPClient(BaseAsyncClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
process_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
async def health(self) -> t.Any: return await self._cached.async_health()
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput:
|
||||
if isinstance(prompt, str): prompt = [prompt]
|
||||
res = await self.acall("embeddings", list(prompt))
|
||||
return openllm.EmbeddingsOutput(**res)
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try: return self._metadata["model_name"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try: return self._metadata["framework"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try: return self._metadata["timeout"]
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try: return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try: return self._metadata.get("supports_embeddings", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try: return self._metadata.get("supports_hf_agent", False)
|
||||
except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)
|
||||
@@ -1,19 +1,3 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
from __future__ import annotations
|
||||
import bentoml
|
||||
class OpenLLMException(bentoml.exceptions.BentoMLException):
|
||||
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
|
||||
class GpuNotAvailableError(OpenLLMException):
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
class ValidationError(OpenLLMException):
|
||||
"""Raised when a validation fails."""
|
||||
class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
class Error(BaseException):
|
||||
"""To be used instead of naked raise."""
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
from openllm_core.exceptions import OpenLLMException as OpenLLMException, GpuNotAvailableError as GpuNotAvailableError, ValidationError as ValidationError, ForbiddenAttributeError as ForbiddenAttributeError, MissingAnnotationAttributeError as MissingAnnotationAttributeError, MissingDependencyError as MissingDependencyError, Error as Error, FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError
|
||||
|
||||
10
openllm-python/src/openllm/models/__init__.py
generated
10
openllm-python/src/openllm/models/__init__.py
generated
@@ -1,11 +1,11 @@
|
||||
# This file is generated by tools/update-models-import.py. DO NOT EDIT MANUALLY!
|
||||
# To update this, run ./tools/update-models-import.py
|
||||
from __future__ import annotations
|
||||
import typing as t, os
|
||||
from openllm.utils import LazyModule
|
||||
_MODELS: set[str] = {"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"}
|
||||
if t.TYPE_CHECKING: from . import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder
|
||||
__lazy=LazyModule(__name__, os.path.abspath("__file__"), {k: [] for k in _MODELS})
|
||||
import typing as t
|
||||
from openllm_core.utils import LazyModule
|
||||
_MODELS:set[str]={"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"}
|
||||
if t.TYPE_CHECKING:from . import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder
|
||||
__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, os
|
||||
import openllm
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
from openllm_core.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], "modeling_auto": ["MODEL_MAPPING_NAMES"], "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"]}
|
||||
_import_structure: dict[str, list[str]] = {"modeling_auto": ["MODEL_MAPPING_NAMES"], "modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"], "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING as CONFIG_MAPPING,
|
||||
CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES,
|
||||
AutoConfig as AutoConfig,
|
||||
)
|
||||
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection, openllm
|
||||
from openllm.utils import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from openllm._typing_compat import LiteralString
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
|
||||
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
|
||||
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
|
||||
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig")])
|
||||
|
||||
class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._mapping = mapping
|
||||
self._extra_content: dict[str, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
|
||||
raise KeyError(key)
|
||||
value, module_name = self._mapping[key], inflection.underscore(key)
|
||||
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
|
||||
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
|
||||
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
|
||||
return getattr(openllm, value)
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: return set(self._mapping.keys())
|
||||
def __repr__(self) -> str: return ReprMixin.__repr__(self)
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items()
|
||||
def keys(self) -> ConfigKeysView: return t.cast("ConfigKeysView", list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
def values(self) -> ConfigValuesView: return t.cast("ConfigValuesView", [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
|
||||
def items(self) -> ConfigItemsView: return t.cast("ConfigItemsView", [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
|
||||
def __iter__(self) -> t.Iterator[str]: return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
def __contains__(self, item: t.Any) -> bool: return item in self._mapping or item in self._extra_content
|
||||
def register(self, key: str, value: t.Any) -> None:
|
||||
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
# The below handle special alias when we call underscore to the name directly without processing camelcase first.
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder", "gpt_neo_x": "gpt_neox",}
|
||||
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
|
||||
model_name = inflection.underscore(model_name)
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
|
||||
model_name = inflection.underscore(name)
|
||||
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
import importlib, inspect, logging, typing as t
|
||||
from collections import OrderedDict
|
||||
import inflection, openllm
|
||||
from openllm.utils import ReprMixin
|
||||
from openllm_core.utils import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import LiteralString, LLMRunner
|
||||
from openllm_core._typing_compat import LiteralString, LLMRunner
|
||||
import types
|
||||
from collections import _odict_items, _odict_keys, _odict_values
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan")])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT")])
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT")])
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass, _LazyAutoMapping
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("baichuan", "VLLMBaichuan"), ("dolly_v2", "VLLMDollyV2"), ("falcon", "VLLMFalcon"), ("gpt_neox", "VLLMGPTNeoX"), ("mpt", "VLLMMPT"), ("opt", "VLLMOPT"), ("stablelm", "VLLMStableLM"), ("starcoder", "VLLMStarCoder"), ("llama", "VLLMLlama")])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_baichuan import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING,
|
||||
BaichuanConfig as BaichuanConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_baichuan": ["BaichuanConfig", "START_BAICHUAN_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_baichuan import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING,
|
||||
BaichuanConfig as BaichuanConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class BaichuanConfig(openllm.LLMConfig):
|
||||
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
|
||||
|
||||
Baichuan-7B is based on Transformer architecture,
|
||||
which contains 7 billion parameters and trained on approximately 1.2 trillion tokens.
|
||||
It supports both Chinese and English languages with a context window length of 4096.
|
||||
It has achieved the best performance among models of the same size on standard Chinese
|
||||
and English benchmarks (C-Eval, MMLU, etc).
|
||||
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/baichuan-inc/Baichuan-7B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "BaiChuanForCausalLM",
|
||||
"default_id": "baichuan-inc/baichuan-7b", "model_ids": ["baichuan-inc/baichuan-7b", "baichuan-inc/baichuan-13b-base", "baichuan-inc/baichuan-13b-chat", "fireballoon/baichuan-vicuna-chinese-7b", "fireballoon/baichuan-vicuna-7b", "hiyouga/baichuan-7b-sft"]}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Baichuan model.
|
||||
|
||||
\b
|
||||
> See more information about Baichuan at [baichuan-inc/Baichuan-7B](https://github.com/baichuan-inc/Baichuan-7B)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, Baichuan only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
Baichuan Runner will use baichuan-inc/Baichuan-7B as the default model. To change to any other
|
||||
saved pretrained Baichuan, provide ``OPENLLM_Baichuan_MODEL_ID='fireballoon/baichuan-vicuna-chinese-7b'``
|
||||
or provide `--model-id` flag when running ``openllm start baichuan``:
|
||||
|
||||
\b
|
||||
$ openllm start baichuan --model-id='fireballoon/baichuan-vicuna-chinese-7b'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
@@ -1,16 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMBaichuan(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available
|
||||
from openllm_core.config.configuration_chatglm import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING,
|
||||
ChatGLMConfig as ChatGLMConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_chatglm": ["ChatGLMConfig", "START_CHATGLM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_chatglm import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING,
|
||||
ChatGLMConfig as ChatGLMConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class ChatGLMConfig(openllm.LLMConfig):
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
|
||||
With the quantization technique, users can deploy locally on consumer-grade graphics cards
|
||||
(only 6GB of GPU memory is required at the INT4 quantization level).
|
||||
|
||||
ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue.
|
||||
The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning,
|
||||
feedback bootstrap, and reinforcement learning wit human feedback.
|
||||
With only about 6.2 billion parameters, the model is able to generate answers that are in line
|
||||
with human preference.
|
||||
|
||||
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "trust_remote_code": True, "timeout": 3600000, "requires_gpu": True, "url": "https://github.com/THUDM/ChatGLM-6B", "requirements": ["cpm-kernels", "sentencepiece"], "architecture": "ChatGLMForConditionalGeneration",
|
||||
"default_id": "thudm/chatglm-6b", "model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4", "thudm/chatglm2-6b", "thudm/chatglm2-6b-int4"]}
|
||||
retain_history: bool = openllm.LLMConfig.Field(False, description="Whether to retain history given to the model. If set to True, then the model will retain given history.")
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
num_beams: int = 1
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
START_CHATGLM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for ChatGLM model.
|
||||
|
||||
\b
|
||||
> See more information about ChatGLM at [THUDM/ChatGLM-6b](https://huggingface.co/thudm/chatglm-6b)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, ChatGLM only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
ChatGLM Runner will use THUDM/ChatGLM-6b as the default model. To change to any other ChatGLM
|
||||
saved pretrained, or a fine-tune ChatGLM, provide ``OPENLLM_CHATGLM_MODEL_ID='thudm/chatglm-6b-int8'``
|
||||
or provide `--model-id` flag when running ``openllm start chatglm``:
|
||||
|
||||
\b
|
||||
$ openllm start chatglm --model-id='thudm/chatglm-6b-int8'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
@@ -1,32 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[tuple[str, str]] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ""
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history): prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:"
|
||||
else: prompt_text = prompt
|
||||
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}, postprocess_generate_kwargs
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any) -> str:
|
||||
generated, history = generation_result
|
||||
if self.config.retain_history:
|
||||
if chat_history is None: raise ValueError("'retain_history' is True while there is no history provided.")
|
||||
chat_history.extend(history)
|
||||
return generated
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
self.model.eval()
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
|
||||
import torch, torch.nn.functional as F
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_dolly_v2 import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING,
|
||||
DollyV2Config as DollyV2Config,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_dolly_v2 import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING,
|
||||
DollyV2Config as DollyV2Config,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class DollyV2Config(openllm.LLMConfig):
|
||||
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
|
||||
|
||||
Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
|
||||
generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
|
||||
classification, closed QA, generation, information extraction, open QA and summarization.
|
||||
|
||||
dolly-v2-12b is not a state-of-the-art model, but does exhibit surprisingly high quality instruction
|
||||
following behavior not characteristic of the foundation model on which it is based.
|
||||
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
__config__ = {"timeout": 3600000, "url": "https://github.com/databrickslabs/dolly", "architecture": "GPTNeoXForCausalLM",
|
||||
"default_id": "databricks/dolly-v2-3b", "model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]}
|
||||
return_full_text: bool = openllm.LLMConfig.Field(False, description="Whether to return the full prompt to the users.")
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
top_p: float = 0.92
|
||||
top_k: int = 5
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for dolly-v2 model.
|
||||
|
||||
\b
|
||||
> See more information about dolly-v2 at [databricks/dolly-v2-3b](https://huggingface.co/databricks/dolly-v2-3b)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, dolly-v2 only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
Dolly-v2 Runner will use databricks/dolly-v2-3b as the default model. To change to any other dolly-v2
|
||||
saved pretrained, or a fine-tune dolly-v2, provide ``OPENLLM_DOLLY_V2_MODEL_ID='databricks/dolly-v2-7b'``
|
||||
or provide `--model-id` flag when running ``openllm start dolly-v2``:
|
||||
|
||||
\b
|
||||
$ openllm start dolly-v2 --model-id databricks/dolly-v2-7b
|
||||
"""
|
||||
INSTRUCTION_KEY = "### Instruction:"
|
||||
RESPONSE_KEY = "### Response:"
|
||||
END_KEY = "### End"
|
||||
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
|
||||
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
|
||||
Args:
|
||||
tokenizer: the tokenizer
|
||||
key: the key to convert to a single token
|
||||
|
||||
Raises:
|
||||
RuntimeError: if more than one ID was generated
|
||||
|
||||
Returns:
|
||||
int: the token ID for the given key.
|
||||
"""
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
return token_ids[0]
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import logging, re, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm._typing_compat import overload
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.config.configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id
|
||||
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
|
||||
@@ -102,8 +101,6 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16}, {}
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
with torch.inference_mode(): return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMDollyV2(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_falcon import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING,
|
||||
FalconConfig as FalconConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_falcon": ["FalconConfig", "START_FALCON_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_falcon import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING,
|
||||
FalconConfig as FalconConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class FalconConfig(openllm.LLMConfig):
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
|
||||
It is made available under the TII Falcon LLM License.
|
||||
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "trust_remote_code": True, "requires_gpu": True, "timeout": int(36e6), "url": "https://falconllm.tii.ae/", "requirements": ["einops", "xformers"], "architecture": "FalconForCausalLM",
|
||||
"default_id": "tiiuae/falcon-7b", "model_ids": ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"],
|
||||
"fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none", "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]},)}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 200
|
||||
top_k: int = 10
|
||||
num_return_sequences: int = 1
|
||||
num_beams: int = 4
|
||||
early_stopping: bool = True
|
||||
|
||||
START_FALCON_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FalconLM model.
|
||||
|
||||
\b
|
||||
> See more information about falcon at [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
By default, this model will use the PyTorch model for inference. However, this model also support vLLM.
|
||||
|
||||
Note that if you use vLLM, a NVIDIA GPU is required.
|
||||
|
||||
\b
|
||||
FalconLM Runner will use tiiuae/falcon-7b as the default model. To change to any other FalconLM
|
||||
saved pretrained, or a fine-tune FalconLM, provide ``OPENLLM_FALCON_MODEL_ID='tiiuae/falcon-7b-instruct'``
|
||||
or provide `--model-id` flag when running ``openllm start falcon``:
|
||||
|
||||
\b
|
||||
$ openllm start falcon --model-id tiiuae/falcon-7b-instruct
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{context}
|
||||
{user_name}: {instruction}
|
||||
{agent}:
|
||||
"""
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
@@ -9,8 +7,6 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMFalcon(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {}
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||
from openllm_core.config.configuration_flan_t5 import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING,
|
||||
FlanT5Config as FlanT5Config,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_flan_t5": ["FlanT5Config", "START_FLAN_T5_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_flan_t5 import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING,
|
||||
FlanT5Config as FlanT5Config,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class FlanT5Config(openllm.LLMConfig):
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
|
||||
It is an enhanced version of T5 that has been finetuned in a mixture of tasks.
|
||||
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
__config__ = {"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5", "architecture": "T5ForConditionalGeneration", "model_type": "seq2seq_lm",
|
||||
"default_id": "google/flan-t5-large", "model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl",]}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 2048
|
||||
top_k: int = 50
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FLAN-T5 model.
|
||||
|
||||
\b
|
||||
> See more information about FLAN-T5 at [huggingface/transformers](https://huggingface.co/docs/transformers/model_doc/flan-t5)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
By default, this model will use the PyTorch model for inference. However, this model supports both Flax and Tensorflow.
|
||||
|
||||
\b
|
||||
- To use Flax, set the environment variable ``OPENLLM_FLAN_T5_FRAMEWORK="flax"``
|
||||
|
||||
\b
|
||||
- To use Tensorflow, set the environment variable ``OPENLLM_FLAN_T5_FRAMEWORK="tf"``
|
||||
|
||||
\b
|
||||
FLAN-T5 Runner will use google/flan-t5-large as the default model. To change to any other FLAN-T5
|
||||
saved pretrained, or a fine-tune FLAN-T5, provide ``OPENLLM_FLAN_T5_MODEL_ID='google/flan-t5-xxl'``
|
||||
or provide `--model-id` flag when running ``openllm start flan-t5``:
|
||||
|
||||
\b
|
||||
$ openllm start flan-t5 --model-id google/flan-t5-xxl
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""
|
||||
@@ -1,17 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
|
||||
import torch, torch.nn.functional as F
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.config.configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
@@ -9,7 +9,6 @@ class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "tra
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if decoder_start_token_id is None: decoder_start_token_id = 0
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "decoder_start_token_id": decoder_start_token_id}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
|
||||
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_gpt_neox import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING,
|
||||
GPTNeoXConfig as GPTNeoXConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_gpt_neox": ["GPTNeoXConfig", "START_GPT_NEOX_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_gpt_neox import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING,
|
||||
GPTNeoXConfig as GPTNeoXConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class GPTNeoXConfig(openllm.LLMConfig):
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
|
||||
It is, to the best of our knowledge, the largest dense autoregressive model
|
||||
that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights,
|
||||
can be found at https://github.com/EleutherAI/gpt-neox.
|
||||
|
||||
GPTNeoX has been used to fine-tune on various models, such as Dolly, StableLM, and Pythia.
|
||||
|
||||
Note that OpenLLM provides first-class support for all of the aforementioned model. Users can
|
||||
also use `openllm start gpt-neox` to run all of the GPTNeoX variant's model
|
||||
|
||||
Refer to [GPTNeoX's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox)
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {"model_name": "gpt_neox", "start_name": "gpt-neox", "requires_gpu": True, "architecture": "GPTNeoXForCausalLM", "url": "https://github.com/EleutherAI/gpt-neox",
|
||||
"default_id": "eleutherai/gpt-neox-20b", "model_ids": ["eleutherai/gpt-neox-20b"]}
|
||||
use_half_precision: bool = openllm.LLMConfig.Field(True, description="Whether to use half precision for model.")
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for GPTNeoX model.
|
||||
|
||||
\b
|
||||
> See more information about GPTNeoX at [HuggingFace's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, GPTNeoX only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
GPTNeoX Runner will use EleutherAI/gpt-neox-20b as the default model. To change to any other GPTNeoX
|
||||
saved pretrained, or a fine-tune GPTNeoX, provide ``OPENLLM_GPT_NEOX_MODEL_ID='stabilityai/stablelm-tuned-alpha-3b'``
|
||||
or provide `--model-id` flag when running ``openllm start gpt-neox``:
|
||||
|
||||
\b
|
||||
$ openllm start gpt-neox --model-id 'stabilityai/stablelm-tuned-alpha-3b'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
@@ -1,20 +1,19 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
|
||||
import transformers
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs)
|
||||
if self.config.use_half_precision: model.half()
|
||||
return model
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()])))
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm, logging
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMGPTNeoX(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
|
||||
@@ -2,15 +2,14 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_llama import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
PROMPT_MAPPING as PROMPT_MAPPING,
|
||||
START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING,
|
||||
LlamaConfig as LlamaConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_llama": ["LlamaConfig", "START_LLAMA_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_llama import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
PROMPT_MAPPING as PROMPT_MAPPING,
|
||||
START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING,
|
||||
LlamaConfig as LlamaConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
|
||||
class LlamaConfig(openllm.LLMConfig):
|
||||
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
|
||||
|
||||
It is a collection of foundation language models ranging from 7B to 65B parameters.
|
||||
|
||||
Llama also include support for the recent propsed [Llama-2](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/)
|
||||
|
||||
Note that all variants of Llama including fine-tuning, quantisation format are all supported with ``openllm.Llama``.
|
||||
|
||||
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
|
||||
for more information.
|
||||
"""
|
||||
use_llama2_prompt: bool = openllm.LLMConfig.Field(False, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
|
||||
__config__ = {"name_type": "lowercase", "url": "https://github.com/facebookresearch/llama", "default_implementation": {"cpu": "pt", "nvidia.com/gpu": "pt"}, "architecture": "LlamaForCausalLM", "requirements": ["fairscale", "sentencepiece"], "tokenizer_class": "LlamaTokenizerFast",
|
||||
"default_id": "NousResearch/llama-2-7b-hf", "model_ids": ["meta-llama/Llama-2-70b-chat-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-70b-hf", "meta-llama/Llama-2-13b-hf",
|
||||
"meta-llama/Llama-2-7b-hf", "NousResearch/llama-2-70b-chat-hf", "NousResearch/llama-2-13b-chat-hf", "NousResearch/llama-2-7b-chat-hf", "NousResearch/llama-2-70b-hf", "NousResearch/llama-2-13b-hf", "NousResearch/llama-2-7b-hf",
|
||||
"openlm-research/open_llama_7b_v2", "openlm-research/open_llama_3b_v2", "openlm-research/open_llama_13b", "huggyllama/llama-65b", "huggyllama/llama-30b", "huggyllama/llama-13b", "huggyllama/llama-7b"],
|
||||
"fine_tune_strategies": ({"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none"},)}
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0.6
|
||||
top_p: float = 0.9
|
||||
top_k: int = 12
|
||||
class SamplingParams:
|
||||
best_of: int = 1
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
START_LLAMA_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Llama model.
|
||||
|
||||
\b
|
||||
> See more information about Llama at [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
By default, this model will use [vLLM](https://github.com/vllm-project/vllm) for inference.
|
||||
This model will also supports PyTorch.
|
||||
|
||||
\b
|
||||
- To use PyTorch, set the environment variable ``OPENLLM_LLAMA_FRAMEWORK="pt"``
|
||||
|
||||
\b
|
||||
Llama Runner will use decapoda-research/llama-7b-hf as the default model. To change to any other Llama
|
||||
saved pretrained, or a fine-tune Llama, provide ``OPENLLM_LLAMA_MODEL_ID='openlm-research/open_llama_7b_v2'``
|
||||
or provide `--model-id` flag when running ``openllm start llama``:
|
||||
|
||||
\b
|
||||
$ openllm start llama --model-id 'openlm-research/open_llama_7b_v2'
|
||||
|
||||
\b
|
||||
OpenLLM also supports running Llama-2 and its fine-tune and variants. To import the Llama weights, one can use the following:
|
||||
|
||||
\b
|
||||
$ CONVERTER=hf-llama2 openllm import llama /path/to/llama-2
|
||||
"""
|
||||
SYSTEM_MESSAGE = """
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
||||
"""
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = "[INST]", "[/INST]", "<<SYS>>", "</s>", "<s>"
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = """{instruction}""", """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} """.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction="{instruction}", end_key=EINST_KEY)
|
||||
PROMPT_MAPPING = {"v1": _v1_prompt, "v2": _v2_prompt}
|
||||
def _get_prompt(model_type: t.Literal["v1", "v2"]) -> str: return PROMPT_MAPPING[model_type]
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
@@ -1,17 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F
|
||||
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
|
||||
import torch, torch.nn.functional as F
|
||||
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
|
||||
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
|
||||
with torch.inference_mode():
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_llama import DEFAULT_PROMPT_TEMPLATE
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {}
|
||||
|
||||
@@ -2,15 +2,14 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_mpt import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
PROMPT_MAPPING as PROMPT_MAPPING,
|
||||
START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING,
|
||||
MPTConfig as MPTConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_mpt": ["MPTConfig", "START_MPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE", "PROMPT_MAPPING"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_mpt import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
PROMPT_MAPPING as PROMPT_MAPPING,
|
||||
START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING,
|
||||
MPTConfig as MPTConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys, typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING: MPTPromptType = t.Literal["default", "instruct", "chat", "storywriter"]
|
||||
else: MPTPromptType = str
|
||||
|
||||
class MPTConfig(openllm.LLMConfig):
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
|
||||
|
||||
This model was trained by [MosaicML](https://www.mosaicml.com/).
|
||||
|
||||
``openllm.MPT`` encapsulate a family of MPT variants that is publicly available
|
||||
on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
|
||||
for more details on specific models.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "trust_remote_code": True, "url": "https://huggingface.co/mosaicml", "timeout": int(36e6), "requirements": ["triton", "einops"], "architecture": "MPTForCausalLM",
|
||||
"default_id": "mosaicml/mpt-7b-instruct", "model_ids": ["mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", "mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-storywriter", "mosaicml/mpt-30b", "mosaicml/mpt-30b-instruct", "mosaicml/mpt-30b-chat"]}
|
||||
prompt_type: MPTPromptType = openllm.LLMConfig.Field('"default"', description="Given prompt type for running MPT. Default will be inferred from model name if pretrained.")
|
||||
max_sequence_length: int = openllm.LLMConfig.Field(2048, description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)")
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
|
||||
START_MPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for MPT model.
|
||||
|
||||
\b
|
||||
> See more information about MPT at [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, MPT only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
If you want to use Flash Attention support with openai/triton, make sure to install OpenLLM with
|
||||
|
||||
\b
|
||||
```bash
|
||||
pip install "openllm[mpt]"
|
||||
```
|
||||
|
||||
\b
|
||||
MPT Runner will use mosaicml/mpt-7b-instruct as the default model. To change to any other MPT
|
||||
saved pretrained, or a fine-tune MPT, provide ``OPENLLM_MPT_MODEL_ID='mosaicml/mpt-30b'``
|
||||
or provide `--model-id` flag when running ``openllm start mpt``:
|
||||
|
||||
\b
|
||||
$ openllm start mpt --model-id mosaicml/mpt-30b
|
||||
"""
|
||||
INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = "### Instruction:", "### Response:", "### End"
|
||||
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
_chat_prompt, _default_prompt, _instruct_prompt = """{instruction}""", """{instruction}""", """{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
|
||||
PROMPT_MAPPING = {"default": _default_prompt, "instruct": _instruct_prompt, "storywriter": _default_prompt, "chat": _chat_prompt}
|
||||
def _get_prompt(model_type: str) -> str: return PROMPT_MAPPING[model_type]
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
@@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels, is_triton_available
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
|
||||
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig:
|
||||
import torch
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
|
||||
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
|
||||
@@ -18,10 +15,15 @@ def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torc
|
||||
return config
|
||||
class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def llm_post_init(self) -> None: self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
def llm_post_init(self) -> None:
|
||||
import torch
|
||||
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch, transformers
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
@@ -33,6 +35,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally: torch.cuda.empty_cache()
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
|
||||
import transformers
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
trust_remote_code = attrs.pop("trust_remote_code", True)
|
||||
@@ -40,18 +43,8 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
model.tie_weights()
|
||||
return model
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if "instruct" in self.model_id: prompt_type = "instruct"
|
||||
elif "storywriter" in self.model_id: prompt_type = "storywriter"
|
||||
elif "chat" in self.model_id: prompt_type = "chat"
|
||||
else: prompt_type = "default"
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()}
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers, vllm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMMPT(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, prompt_type: MPTPromptType | None = None, use_default_prompt_template: bool = True, **attrs: t.Any,) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if "instruct" in self.model_id: prompt_type = "instruct"
|
||||
elif "storywriter" in self.model_id: prompt_type = "storywriter"
|
||||
elif "chat" in self.model_id: prompt_type = "chat"
|
||||
else: prompt_type = "default"
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_opt import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING,
|
||||
OPTConfig as OPTConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_opt": ["OPTConfig", "START_OPT_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_opt import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING,
|
||||
OPTConfig as OPTConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class OPTConfig(openllm.LLMConfig):
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
|
||||
|
||||
OPT was predominantly pretrained with English text, but a small amount of non-English data is still present
|
||||
within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM)
|
||||
objective. OPT belongs to the same family of decoder-only models like GPT-3. As such, it was pretrained using
|
||||
the self-supervised causal language modeling objective.
|
||||
|
||||
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase", "trust_remote_code": False, "url": "https://huggingface.co/docs/transformers/model_doc/opt",
|
||||
"default_id": "facebook/opt-1.3b", "architecture": "OPTForCausalLM", "model_ids": ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-66b"],
|
||||
"fine_tune_strategies": ({"adapter_type": "lora", "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], "lora_dropout": 0.05, "bias": "none"},)
|
||||
}
|
||||
format_outputs: bool = openllm.LLMConfig.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 1024
|
||||
num_return_sequences: int = 1
|
||||
|
||||
START_OPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for OPT model.
|
||||
|
||||
\b
|
||||
> See more information about falcon at [facebook/opt-66b](https://huggingface.co/facebook/opt-66b)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
By default, this model will use the PyTorch model for inference. However, this model supports both Flax and Tensorflow.
|
||||
|
||||
\b
|
||||
- To use Flax, set the environment variable ``OPENLLM_OPT_FRAMEWORK="flax"``
|
||||
|
||||
\b
|
||||
- To use Tensorflow, set the environment variable ``OPENLLM_OPT_FRAMEWORK="tf"``
|
||||
|
||||
\b
|
||||
OPT Runner will use facebook/opt-2.7b as the default model. To change to any other OPT
|
||||
saved pretrained, or a fine-tune OPT, provide ``OPENLLM_OPT_MODEL_ID='facebook/opt-6.7b'``
|
||||
or provide `--model-id` flag when running ``openllm start opt``:
|
||||
|
||||
\b
|
||||
$ openllm start opt --model-id facebook/opt-6.7b
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
@@ -14,8 +14,4 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences, skip_special_tokens=True)
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
import typing as t, bentoml, openllm
|
||||
from openllm_core.utils import generate_labels
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import transformers
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag, transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, num_return_sequences: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
|
||||
else: return "\n".join(generation_result)
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]: return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
import typing as t, openllm
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_stablelm import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING,
|
||||
StableLMConfig as StableLMConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_stablelm": ["StableLMConfig", "START_STABLELM_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_stablelm import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING,
|
||||
StableLMConfig as StableLMConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
class StableLMConfig(openllm.LLMConfig):
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
|
||||
|
||||
It is pre-trained on a diverse collection of English datasets with a sequence
|
||||
length of 4096 to push beyond the context window limitations of existing open-source language models.
|
||||
|
||||
StableLM-Tuned-Alpha is a suite of 3B and 7B parameter decoder-only language models
|
||||
built on top of the StableLM-Base-Alpha models and further fine-tuned on various chat and
|
||||
instruction-following datasets.
|
||||
|
||||
Refer to [StableLM-tuned's model card](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
|
||||
and [StableLM-base's model card](https://huggingface.co/stabilityai/stablelm-base-alpha-7b)
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "url": "https://github.com/Stability-AI/StableLM", "architecture": "GPTNeoXForCausalLM",
|
||||
"default_id": "stabilityai/stablelm-tuned-alpha-3b", "model_ids": ["stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b"]}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 128
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
|
||||
START_STABLELM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StableLM model.
|
||||
|
||||
\b
|
||||
> See more information about StableLM at [stabilityai/stablelm-base-alpha-3b](https://huggingface.co/stabilityai/stablelm-base-alpha-3b)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, StableLM only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
StableLM Runner will use stabilityai/stablelm-base-alpha-3b as the default model. To change to any other StableLM
|
||||
saved pretrained, or a fine-tune StableLM, provide ``OPENLLM_STABLELM_MODEL_ID='stabilityai/stablelm-tuned-alpha-3b'``
|
||||
or provide `--model-id` flag when running ``openllm start stablelm``:
|
||||
|
||||
\b
|
||||
$ openllm start stablelm --model-id 'stabilityai/stablelm-tuned-alpha-3b'
|
||||
"""
|
||||
SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
|
||||
@@ -1,23 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
|
||||
|
||||
if t.TYPE_CHECKING: import transformers, torch
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
def llm_post_init(self) -> None: self.bettertransformer = True if not torch.cuda.is_available() else False
|
||||
def llm_post_init(self) -> None:
|
||||
import torch
|
||||
self.bettertransformer = True if not torch.cuda.is_available() else False
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if "tuned" in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
else: prompt_text = prompt
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from openllm._prompt import process_prompt
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMStableLM(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if "tuned" in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
else: prompt_text = prompt
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
|
||||
|
||||
@@ -2,14 +2,13 @@ from __future__ import annotations
|
||||
import sys, typing as t
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule, is_torch_available, is_vllm_available
|
||||
from openllm_core.config.configuration_starcoder import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING,
|
||||
StarCoderConfig as StarCoderConfig,
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {"configuration_starcoder": ["StarCoderConfig", "START_STARCODER_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"]}
|
||||
if t.TYPE_CHECKING:
|
||||
from .configuration_starcoder import (
|
||||
DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE,
|
||||
START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING,
|
||||
StarCoderConfig as StarCoderConfig,
|
||||
)
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError: pass
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
class StarCoderConfig(openllm.LLMConfig):
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
|
||||
The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150),
|
||||
[a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the
|
||||
[Fill-in-the-Middle](https://arxiv.org/abs/2207.14255) objective on 1 trillion tokens.
|
||||
|
||||
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
|
||||
"""
|
||||
__config__ = {"name_type": "lowercase", "requires_gpu": True, "url": "https://github.com/bigcode-project/starcoder", "architecture": "GPTBigCodeForCausalLM", "requirements": ["bitsandbytes"], "workers_per_resource": 0.5,
|
||||
"default_id": "bigcode/starcoder", "model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"]}
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.2
|
||||
max_new_tokens: int = 256
|
||||
min_new_tokens: int = 32
|
||||
top_k: float = 50
|
||||
top_p: float = 0.95
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
|
||||
START_STARCODER_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StarCoder model.
|
||||
|
||||
\b
|
||||
> See more information about StarCoder at [bigcode/starcoder](https://huggingface.co/bigcode/starcoder)
|
||||
|
||||
\b
|
||||
## Usage
|
||||
|
||||
Currently, StarCoder only supports PyTorch. Make sure ``torch`` is available in your system.
|
||||
|
||||
\b
|
||||
StarCoder Runner will use bigcode/starcoder as the default model. To change to any other StarCoder
|
||||
saved pretrained, or a fine-tune StarCoder, provide ``OPENLLM_STARCODER_MODEL_ID='bigcode/starcoder'``
|
||||
or provide `--model-id` flag when running ``openllm start starcoder``:
|
||||
|
||||
\b
|
||||
$ openllm start starcoder --model-id 'bigcode/starcoder'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = "<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>", "<|endoftext|>", "<FILL_HERE>"
|
||||
@@ -1,34 +1,24 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, bentoml, openllm
|
||||
from openllm.utils import generate_labels
|
||||
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from openllm_core.config.configuration_starcoder import EOD, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch, transformers
|
||||
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally: torch.cuda.empty_cache()
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try: prefix, suffix = prompt.split(FIM_INDICATOR)
|
||||
except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
|
||||
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
|
||||
else: prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
|
||||
# default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
|
||||
# NOTE: support fine-tuning starcoder
|
||||
|
||||
@@ -1,19 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t, openllm
|
||||
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class VLLMStarCoder(openllm.LLM["vllm.LLMEngine", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try: prefix, suffix = prompt.split(FIM_INDICATOR)
|
||||
except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
|
||||
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
|
||||
else: prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
|
||||
# default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
|
||||
|
||||
@@ -26,7 +26,7 @@ from __future__ import annotations
|
||||
import importlib, typing as t
|
||||
import cloudpickle, fs, openllm
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from openllm._typing_compat import M, T, ParamSpec
|
||||
from openllm_core._typing_compat import M, T, ParamSpec
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import typing as t
|
||||
import bentoml, openllm
|
||||
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import M
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import M
|
||||
|
||||
_conversion_strategy = {"pt": "ggml"}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ if t.TYPE_CHECKING:
|
||||
import torch.nn
|
||||
|
||||
from bentoml._internal.models import ModelStore
|
||||
from openllm._typing_compat import DictStrAny, M, T
|
||||
from openllm_core._typing_compat import DictStrAny, M, T
|
||||
else:
|
||||
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import copy, typing as t, openllm
|
||||
import copy, typing as t, openllm_core, openllm
|
||||
from bentoml._internal.models.model import ModelInfo, ModelSignature
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, HUB_ATTRS
|
||||
|
||||
@@ -7,8 +7,8 @@ if t.TYPE_CHECKING:
|
||||
import torch, transformers, bentoml
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
from bentoml._internal.models.model import ModelSignaturesType
|
||||
from openllm._typing_compat import DictStrAny, M, T
|
||||
else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
from openllm_core._typing_compat import DictStrAny, M, T
|
||||
else: transformers, torch = openllm_core.utils.LazyLoader("transformers", globals(), "transformers"), openllm_core.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
@@ -33,7 +33,7 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
__cls = getattr(transformers, openllm.utils.first_not_none(__llm.config["tokenizer_class"], default="AutoTokenizer"), None)
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config["tokenizer_class"], default="AutoTokenizer"), None)
|
||||
if __cls is None: raise ValueError(f"Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`")
|
||||
return __cls
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import typing as t, attr
|
||||
from huggingface_hub import HfApi
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from openllm._typing_compat import M, T
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings)
|
||||
@attr.define(slots=True)
|
||||
|
||||
@@ -4,244 +4,19 @@ User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm
|
||||
from pathlib import Path
|
||||
from circus.exc import ConflictError
|
||||
from bentoml._internal.configuration import (
|
||||
DEBUG_ENV_VAR as DEBUG_ENV_VAR,
|
||||
GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR,
|
||||
QUIET_ENV_VAR as QUIET_ENV_VAR,
|
||||
get_debug_mode as _get_debug_mode,
|
||||
get_quiet_mode as _get_quiet_mode,
|
||||
set_quiet_mode as set_quiet_mode,
|
||||
)
|
||||
from bentoml._internal.models.model import ModelContext as _ModelContext
|
||||
from bentoml._internal.types import LazyType as LazyType
|
||||
from bentoml._internal.utils import (
|
||||
LazyLoader as LazyLoader,
|
||||
bentoml_cattr as bentoml_cattr,
|
||||
calc_dir_size as calc_dir_size,
|
||||
first_not_none as first_not_none,
|
||||
pkg as pkg,
|
||||
reserve_free_port as reserve_free_port,
|
||||
resolve_user_filepath as resolve_user_filepath,
|
||||
)
|
||||
from openllm.utils.lazy import (
|
||||
LazyModule as LazyModule,
|
||||
VersionInfo as VersionInfo,
|
||||
import typing as t, openllm_core
|
||||
from . import (
|
||||
dummy_flax_objects as dummy_flax_objects,
|
||||
dummy_pt_objects as dummy_pt_objects,
|
||||
dummy_tf_objects as dummy_tf_objects,
|
||||
dummy_vllm_objects as dummy_vllm_objects,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import AnyCallable, LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try: from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError: _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation
|
||||
|
||||
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
|
||||
|
||||
def set_debug_mode(enabled: bool, level: int = 1) -> None:
|
||||
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
|
||||
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ[DEBUG_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
|
||||
|
||||
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes): return False
|
||||
raise
|
||||
|
||||
def available_devices() -> tuple[str, ...]:
|
||||
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
|
||||
from openllm._strategies import NvidiaGpuResource
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
|
||||
"""Generate a hash from given file's modification time.
|
||||
|
||||
Args:
|
||||
f: The file to generate the hash from.
|
||||
algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)
|
||||
|
||||
Returns:
|
||||
The generated hash.
|
||||
"""
|
||||
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count() -> int: return len(available_devices())
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
|
||||
"""This makes sure that we don't overwrite any existing attributes on the object."""
|
||||
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
|
||||
if not hasattr(obj, name): _setattr(name, value)
|
||||
|
||||
def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
|
||||
|
||||
# Special debug flag controled via OPENLLMDEVDEBUG
|
||||
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR)))
|
||||
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
|
||||
MYPY = False
|
||||
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
|
||||
|
||||
def get_debug_mode() -> bool: return DEBUG or _get_debug_mode()
|
||||
def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode()
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
"""A filter of all exception."""
|
||||
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
|
||||
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc): return False
|
||||
return True
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
_LOGGING_CONFIG: dict[str, t.Any] = {
|
||||
"version": 1, "disable_existing_loggers": True,
|
||||
"filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}},
|
||||
"handlers": {"bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout"}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING}},
|
||||
"loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}},
|
||||
"root": {"level": logging.WARNING},
|
||||
}
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure logging for OpenLLM.
|
||||
|
||||
Behaves similar to how BentoML loggers are being configured.
|
||||
"""
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
|
||||
elif get_debug_mode() or DEBUG:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
|
||||
else:
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
|
||||
_LOGGING_CONFIG["root"]["level"] = logging.INFO
|
||||
|
||||
logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def in_notebook() -> bool:
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
if t.TYPE_CHECKING:
|
||||
from IPython.core.interactiveshell import InteractiveShell
|
||||
return "IPKernelApp" in t.cast("dict[str, t.Any]", t.cast(t.Callable[[], "InteractiveShell"], get_ipython)().config)
|
||||
except (ImportError, AttributeError): return False
|
||||
|
||||
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")
|
||||
|
||||
class suppress(contextlib.suppress, contextlib.ContextDecorator):
|
||||
"""A version of contextlib.suppress with decorator support.
|
||||
|
||||
>>> @suppress(KeyError)
|
||||
... def key_error():
|
||||
... {}['']
|
||||
>>> key_error()
|
||||
"""
|
||||
|
||||
def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
"""Compose any number of unary functions into a single unary function.
|
||||
|
||||
>>> import textwrap
|
||||
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
|
||||
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
|
||||
>>> strip_and_dedent(compose.__doc__) == expected
|
||||
True
|
||||
|
||||
Compose also allows the innermost function to take arbitrary arguments.
|
||||
|
||||
>>> round_three = lambda x: round(x, ndigits=3)
|
||||
>>> f = compose(round_three, int.__truediv__)
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
return functools.reduce(compose_two, funcs)
|
||||
|
||||
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
|
||||
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
|
||||
|
||||
```python
|
||||
@apply(reversed)
|
||||
def get_numbers(start):
|
||||
"doc for get_numbers"
|
||||
return range(start, start+3)
|
||||
list(get_numbers(4))
|
||||
# [6, 5, 4]
|
||||
```
|
||||
```python
|
||||
get_numbers.__doc__
|
||||
# 'doc for get_numbers'
|
||||
```
|
||||
"""
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
|
||||
@apply(bool)
|
||||
@suppress(FileNotFoundError)
|
||||
def _text_in_file(text: str, filename: Path) -> bool:
|
||||
return any(text in line for line in filename.open())
|
||||
|
||||
def in_docker() -> bool:
|
||||
"""Is this current environment running in docker?
|
||||
|
||||
```python
|
||||
type(in_docker())
|
||||
```
|
||||
"""
|
||||
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
|
||||
|
||||
T, K = t.TypeVar("T"), t.TypeVar("K")
|
||||
|
||||
def resolve_filepath(path: str, ctx: str | None = None) -> str:
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables."""
|
||||
try: return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError: return path
|
||||
|
||||
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
def generate_context(framework_name: str) -> _ModelContext:
|
||||
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
|
||||
if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
|
||||
if openllm.utils.is_tf_available():
|
||||
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
|
||||
framework_versions["tensorflow"] = get_tf_version()
|
||||
if openllm.utils.is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
|
||||
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
import openllm
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]: return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format}
|
||||
|
||||
_TOKENIZER_PREFIX = "_tokenizer_"
|
||||
|
||||
def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
|
||||
import openllm
|
||||
if implementation == "tf": return openllm.AutoTFLLM
|
||||
@@ -250,62 +25,8 @@ def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | o
|
||||
elif implementation == "vllm": return openllm.AutoVLLM
|
||||
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
# NOTE: The set marks contains a set of modules name
|
||||
# that are available above and are whitelisted
|
||||
# to be included in the extra_objects map.
|
||||
_whitelist_modules = {"pkg"}
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
|
||||
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"analytics": [], "codegen": [], "dantic": [], "dummy_flax_objects": [], "dummy_pt_objects": [], "dummy_tf_objects": [], "dummy_vllm_objects": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"],
|
||||
"import_utils": ["OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "require_backends",
|
||||
"is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available",
|
||||
"is_transformers_supports_kbit", "is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "is_sentencepiece_available",
|
||||
"is_xformers_available", "is_fairscale_available"]}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# NOTE: The following exports useful utils from bentoml
|
||||
from . import (
|
||||
analytics as analytics,
|
||||
codegen as codegen,
|
||||
dantic as dantic,
|
||||
dummy_flax_objects as dummy_flax_objects,
|
||||
dummy_pt_objects as dummy_pt_objects,
|
||||
dummy_tf_objects as dummy_tf_objects,
|
||||
dummy_vllm_objects as dummy_vllm_objects,
|
||||
)
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES,
|
||||
OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES,
|
||||
DummyMetaclass as DummyMetaclass,
|
||||
EnvVarMixin as EnvVarMixin,
|
||||
is_autogptq_available as is_autogptq_available,
|
||||
is_bitsandbytes_available as is_bitsandbytes_available,
|
||||
is_cpm_kernels_available as is_cpm_kernels_available,
|
||||
is_datasets_available as is_datasets_available,
|
||||
is_einops_available as is_einops_available,
|
||||
is_fairscale_available as is_fairscale_available,
|
||||
is_flax_available as is_flax_available,
|
||||
is_jupyter_available as is_jupyter_available,
|
||||
is_jupytext_available as is_jupytext_available,
|
||||
is_notebook_available as is_notebook_available,
|
||||
is_peft_available as is_peft_available,
|
||||
is_sentencepiece_available as is_sentencepiece_available,
|
||||
is_tf_available as is_tf_available,
|
||||
is_torch_available as is_torch_available,
|
||||
is_transformers_supports_agent as is_transformers_supports_agent,
|
||||
is_transformers_supports_kbit as is_transformers_supports_kbit,
|
||||
is_triton_available as is_triton_available,
|
||||
is_vllm_available as is_vllm_available,
|
||||
is_xformers_available as is_xformers_available,
|
||||
require_backends as require_backends,
|
||||
)
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
|
||||
__lazy = LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects=_extras)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
__all__ = ["generate_labels", "infer_auto_class", "dummy_flax_objects", "dummy_pt_objects", "dummy_tf_objects", "dummy_vllm_objects"]
|
||||
def __dir__() -> t.Sequence[str]: return sorted(__all__)
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
else: raise AttributeError(f"module {__name__} has no attribute {it}")
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Telemetry related for OpenLLM tracking.
|
||||
|
||||
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib, functools, logging, os, re, typing as t, importlib.metadata
|
||||
import attr, openllm
|
||||
from bentoml._internal.utils import analytics as _internal_analytics
|
||||
from openllm._typing_compat import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = t.TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
|
||||
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
|
||||
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def do_not_track() -> bool: return DO_NOT_TRACK in openllm.utils.ENV_VARS_TRUE_VALUES
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _usage_event_debugging() -> bool: return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
|
||||
|
||||
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
try: return func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
if _usage_event_debugging():
|
||||
if openllm.utils.get_debug_mode(): logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
|
||||
else: logger.info("Tracking Error: %s", err)
|
||||
else: logger.debug("Tracking Error: %s", err)
|
||||
return wrapper
|
||||
|
||||
@silent
|
||||
def track(event_properties: attr.AttrsInstance) -> None:
|
||||
if do_not_track(): return
|
||||
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_bentoml_tracking() -> t.Generator[None, None, None]:
|
||||
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
|
||||
try:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
|
||||
yield
|
||||
finally: os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
class EventMeta:
|
||||
@property
|
||||
def event_name(self) -> str:
|
||||
# camel case to snake case
|
||||
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
|
||||
# remove "_event" suffix
|
||||
suffix_to_remove = "_event"
|
||||
if event_name.endswith(suffix_to_remove): event_name = event_name[:-len(suffix_to_remove)]
|
||||
return event_name
|
||||
|
||||
@attr.define
|
||||
class ModelSaveEvent(EventMeta):
|
||||
module: str
|
||||
model_size_in_kb: float
|
||||
@attr.define
|
||||
class OpenllmCliEvent(EventMeta):
|
||||
cmd_group: str
|
||||
cmd_name: str
|
||||
openllm_version: str = importlib.metadata.version("openllm")
|
||||
# NOTE: reserved for the do_not_track logics
|
||||
duration_in_ms: t.Any = attr.field(default=None)
|
||||
error_type: str = attr.field(default=None)
|
||||
return_code: int = attr.field(default=None)
|
||||
@attr.define
|
||||
class StartInitEvent(EventMeta):
|
||||
model_name: str
|
||||
llm_config: t.Dict[str, t.Any] = attr.field(default=None)
|
||||
@staticmethod
|
||||
def handler(llm_config: openllm.LLMConfig) -> StartInitEvent: return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
|
||||
|
||||
def track_start_init(llm_config: openllm.LLMConfig) -> None:
|
||||
if do_not_track(): return
|
||||
track(StartInitEvent.handler(llm_config))
|
||||
@@ -1,141 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, inspect, linecache, os, logging, string, types, typing as t
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
import orjson
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
import openllm
|
||||
from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
|
||||
PartialAny = functools.partial[t.Any]
|
||||
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = "__model_name__"
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any: return super().vformat(format_string, (), {self.model_keyword: self.model_name})
|
||||
def can_format(self, value: str) -> bool:
|
||||
try:
|
||||
self.parse(value)
|
||||
return True
|
||||
except ValueError: return False
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_id__"
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_adapter_map__"
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent/"_service.py"
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from openllm.utils import DEBUG
|
||||
model_name = llm.config["model_name"]
|
||||
logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
|
||||
with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it: src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
if DEBUG: logger.info("Generated script:\n%s", script)
|
||||
llm_fs.writetext(llm.config["service_name"], script)
|
||||
|
||||
# sentinel object for unequivocal object() getattr
|
||||
_sentinel = object()
|
||||
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
|
||||
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel: return False
|
||||
for base_cls in cls.__mro__[1:]:
|
||||
a = getattr(base_cls, attrib_name, None)
|
||||
if attr is a: return False
|
||||
return True
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
|
||||
if has_own_attribute(cls, "__annotations__"): return cls.__annotations__
|
||||
return t.cast("DictStrAny", {})
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
|
||||
annot = str(annot)
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): annot = annot[1:-1]
|
||||
return annot.startswith(("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar",))
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
|
||||
try: method_or_cls.__module__ = cls.__module__
|
||||
except AttributeError: pass
|
||||
try: method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
|
||||
except AttributeError: pass
|
||||
try: method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
|
||||
except AttributeError: pass
|
||||
return method_or_cls
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
|
||||
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
|
||||
locs: DictStrAny = {}
|
||||
# In order of debuggers like PDB being able to step through the code, we add a fake linecache entry.
|
||||
count = 1
|
||||
base_filename = filename
|
||||
while True:
|
||||
linecache_tuple = (len(script), None, script.splitlines(True), filename)
|
||||
old_val = linecache.cache.setdefault(filename, linecache_tuple)
|
||||
if old_val == linecache_tuple: break
|
||||
else:
|
||||
filename = f"{base_filename[:-1]}-{count}>"
|
||||
count += 1
|
||||
_compile_and_eval(script, globs, locs, filename)
|
||||
return locs[name]
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
|
||||
The subclass is a bare tuple with properties for names.
|
||||
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
attr_class_name = f"{cls_name}Attributes"
|
||||
attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",]
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names): attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
|
||||
else: attr_class_template.append(" pass")
|
||||
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
|
||||
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
|
||||
_compile_and_eval("\n".join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable:
|
||||
from openllm.utils import SHOW_CODEGEN
|
||||
script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
return meth
|
||||
|
||||
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
|
||||
from openllm.utils import dantic, field_env_key
|
||||
def identity(_: str, x_value: t.Any) -> t.Any: return x_value
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
globs = {} if globs is None else globs
|
||||
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
|
||||
lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", " )", " for f in fields", "]"]
|
||||
fields_ann = "list[attr.Attribute[t.Any]]"
|
||||
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
"""Enhance sdk with nice repr that plays well with your brain."""
|
||||
from openllm.utils import ReprMixin
|
||||
if name is None: name = func.__name__.strip("_")
|
||||
_signatures = inspect.signature(func).parameters
|
||||
def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
|
||||
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
|
||||
if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
|
||||
else: doc = func.__doc__
|
||||
return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny", functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
|
||||
|
||||
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"]
|
||||
@@ -1,387 +0,0 @@
|
||||
"""An interface provides the best of pydantic and attrs."""
|
||||
from __future__ import annotations
|
||||
import functools, importlib, os, sys, typing as t
|
||||
from enum import Enum
|
||||
import attr, click, click_option_group as cog, inflection, orjson
|
||||
from click import (
|
||||
ParamType,
|
||||
shell_completion as sc,
|
||||
types as click_types,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING: from attr import _ValidatorType
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"]
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
underscored = inflection.underscore(name)
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
|
||||
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool: full_option_name += f"/--no-{dasherized}"
|
||||
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
|
||||
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
|
||||
else: identifier = f"{model_name}_{underscored}"
|
||||
|
||||
return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
|
||||
|
||||
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
if value is not None and isinstance(value, str):
|
||||
try: return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
piped into first, then all of the other validator will be run afterwards.
|
||||
|
||||
Args:
|
||||
default: The default value for ``dantic.Field``. Defaults to ``None``.
|
||||
ge: Greater than or equal to. Defaults to None.
|
||||
le: Less than or equal to. Defaults to None.
|
||||
validator: Optional attrs-compatible validators type. Default to None
|
||||
description: the documentation for the field. Defaults to None.
|
||||
env: the environment variable to read from. Defaults to None.
|
||||
auto_default: a bool indicating whether to use the default value as the environment.
|
||||
Defaults to False. If set to True, the behaviour of this Field will also depends
|
||||
on kw_only. If kw_only=True, the this field will become 'Required' and the default
|
||||
value is omitted. If kw_only=False, then the default value will be used as before.
|
||||
use_default_converter: a bool indicating whether to use the default converter. Defaults
|
||||
to True. If set to False, then the default converter will not be used.
|
||||
The default converter converts a given value from the environment variable
|
||||
for this given Field.
|
||||
**attrs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop("metadata", {})
|
||||
if description is None: description = "(No description provided)"
|
||||
metadata["description"] = description
|
||||
if env is not None: metadata["env"] = env
|
||||
piped: list[_ValidatorType[t.Any]] = []
|
||||
|
||||
converter = attrs.pop("converter", None)
|
||||
if use_default_converter: converter = functools.partial(env_converter, env=env)
|
||||
|
||||
if ge is not None: piped.append(attr.validators.ge(ge))
|
||||
if le is not None: piped.append(attr.validators.le(le))
|
||||
if validator is not None: piped.append(validator)
|
||||
|
||||
if len(piped) == 0: _validator = None
|
||||
elif len(piped) == 1: _validator = piped[0]
|
||||
else: _validator = attr.validators.and_(*piped)
|
||||
|
||||
factory = attrs.pop("factory", None)
|
||||
if factory is not None and default is not None: raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
# NOTE: the behaviour of this is we will respect factory over the default
|
||||
if factory is not None: attrs["factory"] = factory
|
||||
else: attrs["default"] = default
|
||||
|
||||
kw_only = attrs.pop("kw_only", False)
|
||||
if auto_default and kw_only:
|
||||
attrs.pop("default")
|
||||
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
|
||||
"""Transforms the pydantic field's type into a click-compatible type.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType: click type equivalent
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
|
||||
if t.get_origin(field_type) is t.Union:
|
||||
raise NotImplementedError("Unions are not supported")
|
||||
# enumeration strings or other Enum derivatives
|
||||
if lenient_issubclass(field_type, Enum):
|
||||
return EnumChoice(enum=field_type, case_sensitive=True)
|
||||
# literals are enum-like with way less functionality
|
||||
if is_literal(field_type):
|
||||
return LiteralChoice(value=field_type, case_sensitive=True)
|
||||
# modules, classes, functions
|
||||
if is_typing(field_type): return ModuleType()
|
||||
# entire dictionaries:
|
||||
# using a Dict, convert in advance
|
||||
if is_mapping(field_type): return JsonType()
|
||||
# list, List[p], Tuple[p], Set[p] and so on
|
||||
if is_container(field_type): return parse_container_args(field_type)
|
||||
# bytes are not natively supported by click
|
||||
if lenient_issubclass(field_type, bytes): return BytesType()
|
||||
# return the current type: it should be a primitive
|
||||
return field_type
|
||||
|
||||
def is_typing(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a module-like type.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if the type is itself a type
|
||||
"""
|
||||
raw = t.get_origin(field_type)
|
||||
if raw is None: return False
|
||||
if raw is type or raw is t.Type: return True
|
||||
return False
|
||||
|
||||
def is_literal(field_type: type) -> bool:
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
|
||||
Literals are weird: isinstance and subclass do not work, so you compare
|
||||
the origin with the Literal declaration itself.
|
||||
|
||||
Args:
|
||||
field_type: current pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true if Literal type, false otherwise
|
||||
"""
|
||||
origin = t.get_origin(field_type)
|
||||
return origin is not None and origin is t.Literal
|
||||
|
||||
class ModuleType(ParamType):
|
||||
name = "module"
|
||||
|
||||
def _import_object(self, value: str) -> t.Any:
|
||||
module_name, class_name = value.rsplit(".", maxsplit=1)
|
||||
if not all(s.isidentifier() for s in module_name.split(".")): raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier(): raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
try: return getattr(module, class_name)
|
||||
except AttributeError: raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
|
||||
|
||||
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
try:
|
||||
if isinstance(value, str): return self._import_object(value)
|
||||
return value
|
||||
except Exception as exc: self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
|
||||
class EnumChoice(click.Choice):
|
||||
name = "enum"
|
||||
|
||||
def __init__(self, enum: Enum, case_sensitive: bool = False):
|
||||
"""Enum type support for click that extends ``click.Choice``.
|
||||
|
||||
Args:
|
||||
enum: Given enum
|
||||
case_sensitive: Whether this choice should be case case_sensitive.
|
||||
"""
|
||||
self.mapping = enum
|
||||
self.internal_type = type(enum)
|
||||
choices: list[t.Any] = [e.name for e in enum.__class__]
|
||||
super().__init__(choices, case_sensitive)
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
|
||||
if isinstance(value, self.internal_type):
|
||||
return value
|
||||
result = super().convert(value, param, ctx)
|
||||
if isinstance(result, str):
|
||||
result = self.internal_type[result]
|
||||
return result
|
||||
|
||||
class LiteralChoice(EnumChoice):
|
||||
name = "literal"
|
||||
|
||||
def __init__(self, value: t.Any, case_sensitive: bool = False):
|
||||
"""Literal support for click."""
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
|
||||
_mapping = {str(v): v for v in values}
|
||||
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
|
||||
self.internal_type = item_type
|
||||
|
||||
def allows_multiple(field_type: type[t.Any]) -> bool:
|
||||
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
|
||||
|
||||
For containers, it exploits click's support for lists and such to use the same option multiple times
|
||||
to create a complex object: `python run.py --subsets train --subsets test`
|
||||
# becomes `subsets: ["train", "test"]`.
|
||||
|
||||
Args:
|
||||
field_type: pydantic type.
|
||||
|
||||
Returns:
|
||||
bool: true if it's a composite field (lists, containers and so on), false otherwise
|
||||
"""
|
||||
# Early out for mappings, since it's better to deal with them using strings.
|
||||
if is_mapping(field_type):
|
||||
return False
|
||||
# Activate multiple option for (simple) container types
|
||||
if is_container(field_type):
|
||||
args = parse_container_args(field_type)
|
||||
# A non-composite type has a single argument, such as 'List[int]'
|
||||
# A composite type has a tuple of arguments, like 'Tuple[str, int, int]'.
|
||||
# For the moment, only non-composite types are allowed.
|
||||
return not isinstance(args, tuple)
|
||||
return False
|
||||
|
||||
def is_mapping(field_type: type) -> bool:
|
||||
"""Checks whether this field represents a dictionary or JSON object.
|
||||
|
||||
Args:
|
||||
field_type (type): pydantic type
|
||||
|
||||
Returns:
|
||||
bool: true when the field is a dict-like object, false otherwise.
|
||||
"""
|
||||
# Early out for standard containers.
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Mapping): return True
|
||||
# for everything else or when the typing is more complex, check its origin
|
||||
origin = t.get_origin(field_type)
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Mapping)
|
||||
|
||||
def is_container(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
bool: true if a container, false otherwise
|
||||
"""
|
||||
# do not consider strings or byte arrays as containers
|
||||
if field_type in (str, bytes): return False
|
||||
# Early out for standard containers: list, tuple, range
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Container): return True
|
||||
origin = t.get_origin(field_type)
|
||||
# Early out for non-typing objects
|
||||
if origin is None: return False
|
||||
return lenient_issubclass(origin, t.Container)
|
||||
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
|
||||
Returns:
|
||||
ParamType | tuple[ParamType]: single click-compatible type or a tuple
|
||||
"""
|
||||
if not is_container(field_type):
|
||||
raise ValueError("Field type is not a container type.")
|
||||
args = t.get_args(field_type)
|
||||
# Early out for untyped containers: standard lists, tuples, List[Any]
|
||||
# Use strings when the type is unknown, avoid click's type guessing
|
||||
if len(args) == 0:
|
||||
return click_types.convert_type(str)
|
||||
# Early out for homogenous containers: Tuple[int], List[str]
|
||||
# or homogenous tuples of indefinite length: Tuple[int, ...]
|
||||
if len(args) == 1 or (len(args) == 2 and args[1] is Ellipsis):
|
||||
return parse_single_arg(args[0])
|
||||
# Then deal with fixed-length containers: Tuple[str, int, int]
|
||||
return tuple(parse_single_arg(arg) for arg in args)
|
||||
|
||||
def parse_single_arg(arg: type) -> ParamType:
|
||||
"""Returns the click-compatible type for container origin types.
|
||||
|
||||
In this case, returns string when it's not inferrable, a JSON for mappings
|
||||
and the original type itself in every other case (ints, floats and so on).
|
||||
Bytes is a special case, not natively handled by click.
|
||||
|
||||
Args:
|
||||
arg (type): single argument
|
||||
|
||||
Returns:
|
||||
ParamType: click-compatible type
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
# When we don't know the type, we choose 'str'
|
||||
if arg is t.Any: return click_types.convert_type(str)
|
||||
# For containers and nested models, we use JSON
|
||||
if is_container(arg): return JsonType()
|
||||
if lenient_issubclass(arg, bytes): return BytesType()
|
||||
return click_types.convert_type(arg)
|
||||
|
||||
class BytesType(ParamType):
|
||||
name = "bytes"
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes): return value
|
||||
try: return str.encode(value)
|
||||
except Exception as exc: self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
|
||||
CYGWIN = sys.platform.startswith("cygwin")
|
||||
WIN = sys.platform.startswith("win")
|
||||
if sys.platform.startswith("win") and WIN:
|
||||
def _get_argv_encoding() -> str:
|
||||
import locale
|
||||
return locale.getpreferredencoding()
|
||||
else:
|
||||
def _get_argv_encoding() -> str: return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
|
||||
|
||||
class CudaValueType(ParamType):
|
||||
name = "cuda"
|
||||
envvar_list_splitter = ","
|
||||
is_composite = True
|
||||
typ = click_types.convert_type(str)
|
||||
|
||||
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
|
||||
var = tuple(i for i in rv.split(self.envvar_list_splitter))
|
||||
if "-1" in var:
|
||||
return var[:var.index("-1")]
|
||||
return var
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
|
||||
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
|
||||
|
||||
Args:
|
||||
ctx: Invocation context for this command.
|
||||
param: The parameter that is requesting completion.
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
from openllm.utils import available_devices
|
||||
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
|
||||
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes):
|
||||
enc = _get_argv_encoding()
|
||||
try: value = value.decode(enc)
|
||||
except UnicodeError:
|
||||
fs_enc = sys.getfilesystemencoding()
|
||||
if fs_enc != enc:
|
||||
try: value = value.decode(fs_enc)
|
||||
except UnicodeError: value = value.decode("utf-8", "replace")
|
||||
else: value = value.decode("utf-8", "replace")
|
||||
return tuple(self.typ(x, param, ctx) for x in value.split(","))
|
||||
|
||||
def __repr__(self) -> str: return "STRING"
|
||||
|
||||
CUDA = CudaValueType()
|
||||
|
||||
class JsonType(ParamType):
|
||||
name = "json"
|
||||
def __init__(self, should_load: bool = True) -> None:
|
||||
"""Support JSON type for click.ParamType.
|
||||
|
||||
Args:
|
||||
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
|
||||
"""
|
||||
super().__init__()
|
||||
self.should_load = should_load
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, dict) or not self.should_load: return value
|
||||
try: return orjson.loads(value)
|
||||
except orjson.JSONDecodeError as exc: self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
class FlaxFlanT5(metaclass=_DummyMetaclass):
|
||||
_backends=["flax"]
|
||||
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
class ChatGLM(metaclass=_DummyMetaclass):
|
||||
_backends=["torch","cpm_kernels","sentencepiece"]
|
||||
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","cpm_kernels","sentencepiece"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
class TFFlanT5(metaclass=_DummyMetaclass):
|
||||
_backends=["tensorflow"]
|
||||
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
class VLLMBaichuan(metaclass=_DummyMetaclass):
|
||||
_backends=["vllm","cpm_kernels","sentencepiece"]
|
||||
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","cpm_kernels","sentencepiece"])
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
|
||||
from __future__ import annotations
|
||||
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t
|
||||
from collections import OrderedDict
|
||||
import inflection, packaging.version
|
||||
from bentoml._internal.utils import LazyLoader, pkg
|
||||
from openllm._typing_compat import overload, LiteralString
|
||||
|
||||
from .representation import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
|
||||
from openllm._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
||||
|
||||
def _is_package_available(package: str) -> bool:
|
||||
_package_available = importlib.util.find_spec(package) is not None
|
||||
if _package_available:
|
||||
try: importlib.metadata.version(package)
|
||||
except importlib.metadata.PackageNotFoundError: _package_available = False
|
||||
return _package_available
|
||||
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
_vllm_available = importlib.util.find_spec("vllm") is not None
|
||||
_peft_available = _is_package_available("peft")
|
||||
_einops_available = _is_package_available("einops")
|
||||
_cpm_kernel_available = _is_package_available("cpm_kernels")
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_datasets_available = _is_package_available("datasets")
|
||||
_triton_available = _is_package_available("triton")
|
||||
_jupyter_available = _is_package_available("jupyter")
|
||||
_jupytext_available = _is_package_available("jupytext")
|
||||
_notebook_available = _is_package_available("notebook")
|
||||
_autogptq_available = _is_package_available("auto_gptq")
|
||||
_sentencepiece_available = _is_package_available("sentencepiece")
|
||||
_xformers_available = _is_package_available("xformers")
|
||||
_fairscale_available = _is_package_available("fairscale")
|
||||
|
||||
def is_transformers_supports_kbit() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
|
||||
def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
|
||||
def is_jupyter_available() -> bool: return _jupyter_available
|
||||
def is_jupytext_available() -> bool: return _jupytext_available
|
||||
def is_notebook_available() -> bool: return _notebook_available
|
||||
def is_triton_available() -> bool: return _triton_available
|
||||
def is_datasets_available() -> bool: return _datasets_available
|
||||
def is_peft_available() -> bool: return _peft_available
|
||||
def is_einops_available() -> bool: return _einops_available
|
||||
def is_cpm_kernels_available() -> bool: return _cpm_kernel_available
|
||||
def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
|
||||
def is_autogptq_available() -> bool: return _autogptq_available
|
||||
def is_vllm_available() -> bool: return _vllm_available
|
||||
def is_sentencepiece_available() -> bool: return _sentencepiece_available
|
||||
def is_xformers_available() -> bool: return _xformers_available
|
||||
def is_fairscale_available() -> bool: return _fairscale_available
|
||||
def is_torch_available() -> bool:
|
||||
global _torch_available
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
if _torch_available:
|
||||
try: importlib.metadata.version("torch")
|
||||
except importlib.metadata.PackageNotFoundError: _torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
_torch_available = False
|
||||
return _torch_available
|
||||
def is_tf_available() -> bool:
|
||||
global _tf_available
|
||||
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True
|
||||
else:
|
||||
_tf_version = None
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
if _tf_available:
|
||||
candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",)
|
||||
_tf_version = None
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for _pkg in candidates:
|
||||
try:
|
||||
_tf_version = importlib.metadata.version(_pkg)
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
|
||||
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
return _tf_available
|
||||
def is_flax_available() -> bool:
|
||||
global _flax_available
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
if _flax_available:
|
||||
try:
|
||||
importlib.metadata.version("jax")
|
||||
importlib.metadata.version("flax")
|
||||
except importlib.metadata.PackageNotFoundError: _flax_available = False
|
||||
else:
|
||||
_flax_available = False
|
||||
return _flax_available
|
||||
|
||||
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "VLLM", but are otherwise identically named to our PyTorch classes.
|
||||
If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_TF = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
||||
If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_FLAX = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a Flax installation. Flax classes begin
|
||||
with "Flax", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the Flax equivalent of the class you tried to import would be "Flax{0}".
|
||||
If you want to use Flax, please use Flax classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = """\
|
||||
{0} requires the PyTorch library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
||||
If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use PyTorch please go to
|
||||
https://pytorch.org/get-started/locally/ and follow the instructions that
|
||||
match your environment.
|
||||
"""
|
||||
TF_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
{0} requires the TensorFlow library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "TF", but are otherwise identically named to our TF classes.
|
||||
If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use TensorFlow, please follow the instructions on the
|
||||
installation page https://www.tensorflow.org/install that match your environment.
|
||||
"""
|
||||
TENSORFLOW_IMPORT_ERROR = """{0} requires the TensorFlow library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR = """{0} requires the PyTorch library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR = """{0} requires the vLLM library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/vllm-project/vllm
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
CPM_KERNELS_IMPORT_ERROR = """{0} requires the cpm_kernels library but it was not found in your environment.
|
||||
You can install it with pip: `pip install cpm_kernels`. Please note that you may need to restart your
|
||||
runtime after installation.
|
||||
"""
|
||||
EINOPS_IMPORT_ERROR = """{0} requires the einops library but it was not found in your environment.
|
||||
You can install it with pip: `pip install einops`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
TRITON_IMPORT_ERROR = """{0} requires the triton library but it was not found in your environment.
|
||||
You can install it with pip: 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'.
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
DATASETS_IMPORT_ERROR = """{0} requires the datasets library but it was not found in your environment.
|
||||
You can install it with pip: `pip install datasets`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
PEFT_IMPORT_ERROR = """{0} requires the peft library but it was not found in your environment.
|
||||
You can install it with pip: `pip install peft`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
BITSANDBYTES_IMPORT_ERROR = """{0} requires the bitsandbytes library but it was not found in your environment.
|
||||
You can install it with pip: `pip install bitsandbytes`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
AUTOGPTQ_IMPORT_ERROR = """{0} requires the auto-gptq library but it was not found in your environment.
|
||||
You can install it with pip: `pip install auto-gptq`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
SENTENCEPIECE_IMPORT_ERROR = """{0} requires the sentencepiece library but it was not found in your environment.
|
||||
You can install it with pip: `pip install sentencepiece`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
XFORMERS_IMPORT_ERROR = """{0} requires the xformers library but it was not found in your environment.
|
||||
You can install it with pip: `pip install xformers`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
FAIRSCALE_IMPORT_ERROR = """{0} requires the fairscale library but it was not found in your environment.
|
||||
You can install it with pip: `pip install fairscale`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
|
||||
("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
||||
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
||||
("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
|
||||
|
||||
class DummyMetaclass(abc.ABCMeta):
|
||||
"""Metaclass for dummy object.
|
||||
|
||||
It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
|
||||
"""
|
||||
_backends: t.List[str]
|
||||
def __getattribute__(cls, key: str) -> t.Any:
|
||||
if key.startswith("_"): return super().__getattribute__(key)
|
||||
require_backends(cls, cls._backends)
|
||||
|
||||
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
|
||||
if not isinstance(backends, (list, tuple)): backends = list(backends)
|
||||
name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
|
||||
# Raise an error for users who might not realize that classes without "TF" are torch-only
|
||||
if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
# Raise the inverse error for PyTorch users trying to load TF classes
|
||||
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
# Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
|
||||
if "vllm" in backends:
|
||||
if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()]
|
||||
if failed: raise ImportError("".join(failed))
|
||||
|
||||
class EnvVarMixin(ReprMixin):
|
||||
model_name: str
|
||||
config: str
|
||||
model_id: str
|
||||
quantize: str
|
||||
framework: str
|
||||
bettertransformer: str
|
||||
runtime: str
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["config"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["framework"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
|
||||
def __getitem__(self, item: str | t.Any) -> t.Any:
|
||||
if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
raise KeyError(f"Key {item} not found in {self}")
|
||||
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
|
||||
"""EnvVarMixin is a mixin class that returns the value extracted from environment variables."""
|
||||
from openllm._configuration import field_env_key
|
||||
self.model_name = inflection.underscore(model_name)
|
||||
self._implementation = implementation
|
||||
self._model_id = model_id
|
||||
self._bettertransformer = bettertransformer
|
||||
self._quantize = quantize
|
||||
self._runtime = runtime
|
||||
for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper()))
|
||||
def _quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], first_not_none(os.environ.get(self["quantize"]), default=self._quantize))
|
||||
def _framework_value(self) -> LiteralRuntime:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Literal["pt", "tf", "flax", "vllm"], first_not_none(os.environ.get(self["framework"]), default=self._implementation))
|
||||
def _bettertransformer_value(self) -> bool:
|
||||
from . import first_not_none
|
||||
return t.cast(bool, first_not_none(os.environ.get(self["bettertransformer"], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer))
|
||||
def _model_id_value(self) -> str | None:
|
||||
from . import first_not_none
|
||||
return first_not_none(os.environ.get(self["model_id"]), default=self._model_id)
|
||||
def _runtime_value(self) -> t.Literal["ggml", "transformers"]:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Literal["ggml", "transformers"], first_not_none(os.environ.get(self["runtime"]), default=self._runtime))
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]: return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
|
||||
@property
|
||||
def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
|
||||
@property
|
||||
def module(self) -> LazyLoader: return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
|
||||
@@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t
|
||||
import attr, openllm
|
||||
|
||||
__all__ = ["VersionInfo", "LazyModule"]
|
||||
# vendorred from attrs
|
||||
@functools.total_ordering
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True, repr=False)
|
||||
class VersionInfo:
|
||||
major: int = attr.field()
|
||||
minor: int = attr.field()
|
||||
micro: int = attr.field()
|
||||
releaselevel: str = attr.field()
|
||||
@classmethod
|
||||
def from_version_string(cls, s: str) -> VersionInfo:
|
||||
v = s.split(".")
|
||||
if len(v) == 3: v.append("final")
|
||||
return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])
|
||||
def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
|
||||
cmp = attr.astuple(other) if self.__class__ is other.__class__ else other
|
||||
if not isinstance(cmp, tuple): raise NotImplementedError
|
||||
if not (1 <= len(cmp) <= 4): raise NotImplementedError
|
||||
return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp)
|
||||
def __eq__(self, other: t.Any) -> bool:
|
||||
try: us, them = self._ensure_tuple(other)
|
||||
except NotImplementedError: return NotImplemented
|
||||
return us == them
|
||||
def __lt__(self, other: t.Any) -> bool:
|
||||
try: us, them = self._ensure_tuple(other)
|
||||
except NotImplementedError: return NotImplemented
|
||||
# Since alphabetically "dev0" < "final" < "post1" < "post2", we don't have to do anything special with releaselevel for now.
|
||||
return us < them
|
||||
def __repr__(self) -> str: return "{0}.{1}.{2}".format(*attr.astuple(self)[:3])
|
||||
|
||||
_sentinel, _reserved_namespace = object(), {"__openllm_migration__"}
|
||||
class LazyModule(types.ModuleType):
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self, name: str, module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, doc: str | None = None, extra_objects: dict[str, t.Any] | None = None):
|
||||
"""Lazily load this module as an object.
|
||||
|
||||
It does instantiate a __all__ and __dir__ for IDE support
|
||||
|
||||
Args:
|
||||
name: module name
|
||||
module_file: the given file. Often default to 'globals()['__file__']'
|
||||
import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
|
||||
module_spec: __spec__ of the lazily loaded module
|
||||
doc: Optional docstring for this module.
|
||||
extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well as any locals() functions
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module: dict[str, str] = {}
|
||||
_extra_objects = {} if extra_objects is None else extra_objects
|
||||
for key, values in import_structure.items():
|
||||
for value in values: self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
|
||||
self.__file__ = module_file
|
||||
self.__spec__ = module_spec or importlib.util.find_spec(name)
|
||||
self.__path__ = [os.path.dirname(module_file)]
|
||||
self.__doc__ = doc
|
||||
self._name = name
|
||||
self._objects = _extra_objects
|
||||
self._import_structure = import_structure
|
||||
def __dir__(self) -> list[str]:
|
||||
result = t.cast("list[str]", super().__dir__())
|
||||
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
||||
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
||||
return result + [i for i in self.__all__ if i not in result]
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
"""Equivocal __getattr__ implementation.
|
||||
|
||||
It checks from _objects > _modules and does it recursively.
|
||||
|
||||
It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
|
||||
"""
|
||||
if name in _reserved_namespace: raise openllm.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
|
||||
dunder_to_metadata = {"__title__": "Name", "__copyright__": "", "__version__": "version", "__version_info__": "version", "__description__": "summary", "__uri__": "", "__url__": "", "__author__": "", "__email__": "", "__license__": "license", "__homepage__": ""}
|
||||
if name in dunder_to_metadata:
|
||||
if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
|
||||
meta = importlib.metadata.metadata("openllm")
|
||||
project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
|
||||
if name == "__license__": return "Apache-2.0"
|
||||
elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ("__uri__", "__url__"): return project_url["GitHub"]
|
||||
elif name == "__homepage__": return project_url["Homepage"]
|
||||
elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
|
||||
elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
|
||||
elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
|
||||
return meta[dunder_to_metadata[name]]
|
||||
if "__openllm_migration__" in self._objects:
|
||||
cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
|
||||
if cur_value is not _sentinel:
|
||||
warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
|
||||
return getattr(self, cur_value)
|
||||
if name in self._objects: return self._objects.__getitem__(name)
|
||||
if name in self._modules: value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
|
||||
else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
def _get_module(self, module_name: str) -> types.ModuleType:
|
||||
try: return importlib.import_module("." + module_name, self.__name__)
|
||||
except Exception as e: raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
|
||||
# make sure this module is picklable
|
||||
def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
@@ -1,32 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
import attr, orjson
|
||||
from openllm import utils
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias
|
||||
|
||||
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
|
||||
class ReprMixin:
|
||||
@property
|
||||
@abstractmethod
|
||||
def __repr_keys__(self) -> set[str]: raise NotImplementedError
|
||||
"""This can be overriden by base class using this mixin."""
|
||||
def __repr__(self) -> str: return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"
|
||||
"""The `__repr__` for any subclass of Mixin.
|
||||
|
||||
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
|
||||
"""
|
||||
def __str__(self) -> str: return self.__repr_str__(" ")
|
||||
"""The string representation of the given Mixin subclass.
|
||||
|
||||
It will contains all of the attributes from __repr_keys__
|
||||
"""
|
||||
def __repr_name__(self) -> str: return self.__class__.__name__
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
|
||||
"""To be used with __str__."""
|
||||
def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
"""This can also be overriden by base class using this mixin.
|
||||
|
||||
By default it does a getattr of the current object from __repr_keys__.
|
||||
"""
|
||||
Reference in New Issue
Block a user