Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-28 01:14:09 -05:00)
perf: unify LLM interface (#518)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,4 +1,4 @@
|
||||
"""OpenLLM.
|
||||
'''OpenLLM.
|
||||
|
||||
An open platform for operating large language models in production. Fine-tune, serve,
|
||||
deploy, and monitor any LLMs with ease.
|
||||
@@ -7,16 +7,40 @@ deploy, and monitor any LLMs with ease.
|
||||
* Option to bring your own fine-tuned LLMs
|
||||
* Online Serving with HTTP, gRPC, SSE (coming soon) or custom API
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import logging as _logging, os as _os, typing as _t, warnings as _warnings, openllm_core
|
||||
from pathlib import Path as _Path
|
||||
from . import exceptions as exceptions, utils as utils
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
import typing as _t
|
||||
import warnings as _warnings
|
||||
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
from openllm_core._schema import GenerateInput as GenerateInput, GenerateOutput as GenerateOutput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput
|
||||
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
|
||||
from pathlib import Path as _Path
|
||||
|
||||
import openllm_core
|
||||
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig
|
||||
from openllm_core._configuration import LLMConfig as LLMConfig
|
||||
from openllm_core._configuration import SamplingParams as SamplingParams
|
||||
from openllm_core._schemas import GenerationInput as GenerationInput
|
||||
from openllm_core._schemas import GenerationOutput as GenerationOutput
|
||||
from openllm_core._schemas import MetadataOutput as MetadataOutput
|
||||
from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
from openllm_core.config import AutoConfig as AutoConfig
|
||||
from openllm_core.config import BaichuanConfig as BaichuanConfig
|
||||
from openllm_core.config import ChatGLMConfig as ChatGLMConfig
|
||||
from openllm_core.config import DollyV2Config as DollyV2Config
|
||||
from openllm_core.config import FalconConfig as FalconConfig
|
||||
from openllm_core.config import FlanT5Config as FlanT5Config
|
||||
from openllm_core.config import GPTNeoXConfig as GPTNeoXConfig
|
||||
from openllm_core.config import LlamaConfig as LlamaConfig
|
||||
from openllm_core.config import MPTConfig as MPTConfig
|
||||
from openllm_core.config import OPTConfig as OPTConfig
|
||||
from openllm_core.config import StableLMConfig as StableLMConfig
|
||||
from openllm_core.config import StarCoderConfig as StarCoderConfig
|
||||
|
||||
from . import exceptions as exceptions
|
||||
from . import utils as utils
|
||||
|
||||
if openllm_core.utils.DEBUG:
|
||||
openllm_core.utils.set_debug_mode(True)
|
||||
@@ -24,163 +48,64 @@ if openllm_core.utils.DEBUG:
|
||||
_logging.basicConfig(level=_logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ["BITSANDBYTES_NOWELCOME"] = _os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
_os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
# NOTE: The following warnings come from bitsandbytes and are probably not important for users to see when DEBUG is False
|
||||
_warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
_warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
_warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support.")
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
_warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
# NOTE: ignore the following warning from ghapi as it is not important for users
|
||||
_warnings.filterwarnings("ignore", message="Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated")
|
||||
_warnings.filterwarnings('ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated')
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"exceptions": [],
|
||||
"models": [],
|
||||
"client": [],
|
||||
"bundle": [],
|
||||
"playground": [],
|
||||
"testing": [],
|
||||
"prompts": ["PromptTemplate"],
|
||||
"protocol": ["openai"],
|
||||
"utils": ["infer_auto_class"],
|
||||
"serialisation": ["ggml", "transformers"],
|
||||
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
"_quantisation": ["infer_quantisation_config"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
|
||||
"_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
|
||||
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"],
|
||||
"models.chatglm": [],
|
||||
"models.baichuan": [],
|
||||
"models.dolly_v2": [],
|
||||
"models.falcon": [],
|
||||
"models.flan_t5": [],
|
||||
"models.gpt_neox": [],
|
||||
"models.llama": [],
|
||||
"models.mpt": [],
|
||||
"models.opt": [],
|
||||
"models.stablelm": [],
|
||||
"models.starcoder": []
|
||||
'exceptions': [],
|
||||
'client': [],
|
||||
'bundle': [],
|
||||
'playground': [],
|
||||
'testing': [],
|
||||
'prompts': ['PromptTemplate'],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
|
||||
'_generation': ['StopSequenceCriteria', 'StopOnTokens', 'LogitsProcessorList', 'StoppingCriteriaList', 'prepare_logits_processor'],
|
||||
}
|
||||
COMPILED = _Path(__file__).suffix in (".pyd", ".so")
|
||||
COMPILED = _Path(__file__).suffix in ('.pyd', '.so')
|
||||
|
||||
if _t.TYPE_CHECKING:
|
||||
from . import bundle as bundle, cli as cli, client as client, models as models, playground as playground, serialisation as serialisation, testing as testing
|
||||
from ._generation import LogitsProcessorList as LogitsProcessorList, StopOnTokens as StopOnTokens, StoppingCriteriaList as StoppingCriteriaList, StopSequenceCriteria as StopSequenceCriteria, prepare_logits_processor as prepare_logits_processor
|
||||
from ._llm import LLM as LLM, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
|
||||
from . import bundle as bundle
|
||||
from . import cli as cli
|
||||
from . import client as client
|
||||
from . import playground as playground
|
||||
from . import serialisation as serialisation
|
||||
from . import testing as testing
|
||||
from . import utils as utils
|
||||
from ._generation import LogitsProcessorList as LogitsProcessorList
|
||||
from ._generation import StopOnTokens as StopOnTokens
|
||||
from ._generation import StoppingCriteriaList as StoppingCriteriaList
|
||||
from ._generation import StopSequenceCriteria as StopSequenceCriteria
|
||||
from ._generation import prepare_logits_processor as prepare_logits_processor
|
||||
from ._llm import LLM as LLM
|
||||
from ._llm import LLMRunnable as LLMRunnable
|
||||
from ._llm import LLMRunner as LLMRunner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
from ._deprecated import Runner as Runner
|
||||
from .cli._sdk import build as build
|
||||
from .cli._sdk import import_model as import_model
|
||||
from .cli._sdk import list_models as list_models
|
||||
from .cli._sdk import start as start
|
||||
from .cli._sdk import start_grpc as start_grpc
|
||||
from .prompts import PromptTemplate as PromptTemplate
|
||||
from .protocol import openai as openai
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_cpm_kernels_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"]
|
||||
else:
|
||||
_import_structure["models.chatglm"].extend(["ChatGLM"])
|
||||
_import_structure["models.baichuan"].extend(["Baichuan"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
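The try/except blocks above and below gate exports on optional dependencies: if a backend-specific package is missing, the names are routed to dummy placeholder modules instead. A condensed, standalone sketch of that gating pattern (the `is_available` helper and `exports` dict are illustrative, not OpenLLM APIs):

```python
import importlib.util

def is_available(pkg: str) -> bool:
  # Mirrors the role of the openllm_core.utils.is_*_available() helpers.
  return importlib.util.find_spec(pkg) is not None

exports: dict = {'utils.dummy_pt_objects': []}
if is_available('torch') and is_available('cpm_kernels'):
  exports.setdefault('models.chatglm', []).append('ChatGLM')
else:
  # Missing dependency: expose a dummy object that raises a helpful error on use.
  exports['utils.dummy_pt_objects'].append('ChatGLM')
print(exports)
```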
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_triton_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["MPT"]
|
||||
else:
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_einops_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"]
|
||||
else:
|
||||
_import_structure["models.falcon"].extend(["Falcon"])
|
||||
if _t.TYPE_CHECKING: from .models.falcon import Falcon as Falcon
|
||||
|
||||
try:
|
||||
if not openllm_core.utils.is_torch_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = [
|
||||
name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")
|
||||
]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlanT5"])
|
||||
_import_structure["models.dolly_v2"].extend(["DollyV2"])
|
||||
_import_structure["models.starcoder"].extend(["StarCoder"])
|
||||
_import_structure["models.stablelm"].extend(["StableLM"])
|
||||
_import_structure["models.opt"].extend(["OPT"])
|
||||
_import_structure["models.gpt_neox"].extend(["GPTNeoX"])
|
||||
_import_structure["models.llama"].extend(["Llama"])
|
||||
_import_structure["models.auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
|
||||
from .models.dolly_v2 import DollyV2 as DollyV2
|
||||
from .models.flan_t5 import FlanT5 as FlanT5
|
||||
from .models.gpt_neox import GPTNeoX as GPTNeoX
|
||||
from .models.llama import Llama as Llama
|
||||
from .models.opt import OPT as OPT
|
||||
from .models.stablelm import StableLM as StableLM
|
||||
from .models.starcoder import StarCoder as StarCoder
|
||||
try:
|
||||
if not openllm_core.utils.is_vllm_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.baichuan"].extend(["VLLMBaichuan"])
|
||||
_import_structure["models.llama"].extend(["VLLMLlama"])
|
||||
_import_structure["models.opt"].extend(["VLLMOPT"])
|
||||
_import_structure["models.dolly_v2"].extend(["VLLMDollyV2"])
|
||||
_import_structure["models.falcon"].extend(["VLLMFalcon"])
|
||||
_import_structure["models.gpt_neox"].extend(["VLLMGPTNeoX"])
|
||||
_import_structure["models.mpt"].extend(["VLLMMPT"])
|
||||
_import_structure["models.stablelm"].extend(["VLLMStableLM"])
|
||||
_import_structure["models.starcoder"].extend(["VLLMStarCoder"])
|
||||
_import_structure["models.auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
|
||||
from .models.baichuan import VLLMBaichuan as VLLMBaichuan
|
||||
from .models.dolly_v2 import VLLMDollyV2 as VLLMDollyV2
|
||||
from .models.gpt_neox import VLLMGPTNeoX as VLLMGPTNeoX
|
||||
from .models.falcon import VLLMFalcon as VLLMFalcon
|
||||
from .models.llama import VLLMLlama as VLLMLlama
|
||||
from .models.mpt import VLLMMPT as VLLMMPT
|
||||
from .models.opt import VLLMOPT as VLLMOPT
|
||||
from .models.stablelm import VLLMStableLM as VLLMStableLM
|
||||
from .models.starcoder import VLLMStarCoder as VLLMStarCoder
|
||||
try:
|
||||
if not openllm_core.utils.is_flax_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
|
||||
_import_structure["models.opt"].extend(["FlaxOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
from .models.flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
from .models.opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not openllm_core.utils.is_tf_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
|
||||
_import_structure["models.opt"].extend(["TFOPT"])
|
||||
_import_structure["models.auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
if _t.TYPE_CHECKING:
|
||||
from .models.auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
from .models.flan_t5 import TFFlanT5 as TFFlanT5
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
from .serialisation import ggml as ggml
|
||||
from .serialisation import transformers as transformers
|
||||
from .entrypoints import mount_entrypoints as mount_entrypoints
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED})
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED})
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
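`LazyModule` defers the actual submodule imports until an attribute is first accessed, via the module-level `__getattr__`/`__dir__` hooks assigned above (PEP 562). A minimal sketch of that pattern, not the real `openllm_core.utils.LazyModule` (the stdlib `json` module stands in for a heavy submodule):

```python
from __future__ import annotations
import importlib
import typing as t

_import_structure: dict = {'json': ['dumps', 'loads']}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}

def __getattr__(name: str) -> t.Any:
  # Import the owning module only on first attribute access.
  if name in _attr_to_module:
    return getattr(importlib.import_module(_attr_to_module[name]), name)
  raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

def __dir__() -> list:
  return sorted(_attr_to_module)
```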
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
'''LLM assignment magic.'''
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from openllm.exceptions import OpenLLMException
|
||||
from openllm_core._configuration import _object_getattribute
|
||||
from openllm_core._configuration import _setattr_class
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import ListStr
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
from openllm_core._typing_compat import import_model_protocol
|
||||
from openllm_core._typing_compat import llm_post_init_protocol
|
||||
from openllm_core._typing_compat import load_model_protocol
|
||||
from openllm_core._typing_compat import load_tokenizer_protocol
|
||||
from openllm_core.utils import LazyLoader
|
||||
from openllm_core.utils import codegen
|
||||
from openllm_core.utils import device_count
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import is_torch_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
import vllm
|
||||
|
||||
import bentoml
|
||||
|
||||
from openllm._llm import LLM
|
||||
else:
|
||||
torch = LazyLoader('torch', globals(), 'torch')
|
||||
vllm = LazyLoader('vllm', globals(), 'vllm')
|
||||
|
||||
def import_model(fn: import_model_protocol[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]:
|
||||
@functools.wraps(fn)
|
||||
def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model:
|
||||
(model_decls, model_attrs), _ = self.llm_parameters
|
||||
decls = (*model_decls, *decls)
|
||||
attrs = {**model_attrs, **attrs}
|
||||
return fn(self, *decls, trust_remote_code=first_not_none(trust_remote_code, default=self.trust_remote_code), **attrs)
|
||||
|
||||
return inner
|
||||
|
||||
def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.AsyncLLMEngine]:
|
||||
@functools.wraps(fn)
|
||||
def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.AsyncLLMEngine:
|
||||
if self.__llm_backend__ == 'vllm':
|
||||
num_gpus, dev = 1, device_count()
|
||||
if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
|
||||
try:
|
||||
return vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(model=self._bentomodel.path,
|
||||
tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False))
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from None
|
||||
else:
|
||||
(model_decls, model_attrs), _ = self.llm_parameters
|
||||
decls = (*model_decls, *decls)
|
||||
attrs = {**model_attrs, **attrs}
|
||||
return fn(self, *decls, **attrs)
|
||||
|
||||
return inner
|
||||
|
||||
def load_tokenizer(fn: load_tokenizer_protocol[M, T]) -> t.Callable[[LLM[M, T]], T]:
|
||||
@functools.wraps(fn)
|
||||
def inner(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T:
|
||||
return fn(self, **{**self.llm_parameters[-1], **tokenizer_attrs})
|
||||
|
||||
return inner
|
||||
|
||||
def llm_post_init(fn: llm_post_init_protocol[M, T]) -> t.Callable[[LLM[M, T]], None]:
|
||||
@functools.wraps(fn)
|
||||
def inner(self: LLM[M, T]) -> None:
|
||||
if self.__llm_backend__ == 'pt' and is_torch_available():
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
fn(self)
|
||||
|
||||
return inner
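Each of the decorators above (`import_model`, `load_model`, `load_tokenizer`, `llm_post_init`) follows the same shape: wrap the user-provided implementation with `functools.wraps` and merge stored defaults into the call. A self-contained sketch of that wrapper pattern (the `with_defaults` name and the toy `load` function are illustrative only):

```python
import functools
import typing as t

def with_defaults(defaults: dict):
  def decorator(fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
    @functools.wraps(fn)
    def inner(*args: t.Any, **attrs: t.Any) -> t.Any:
      # Call-site kwargs take precedence over the stored defaults.
      return fn(*args, **{**defaults, **attrs})
    return inner
  return decorator

@with_defaults({'trust_remote_code': False})
def load(model_id: str, **attrs: t.Any) -> dict:
  return {'model_id': model_id, **attrs}

print(load('facebook/opt-125m'))  # {'model_id': 'facebook/opt-125m', 'trust_remote_code': False}
```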
|
||||
|
||||
def make_llm_attributes(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]:
|
||||
'''Make LLM attributes for the given LLM subclass.'''
|
||||
from ._llm import LLM
|
||||
from ._llm import LLMFunction
|
||||
from ._llm import LLMInterface
|
||||
from ._llm import LLMSerialisation
|
||||
|
||||
args: ListStr = []
|
||||
globs: DictStrAny = {'cls': cls, '__wrapped_llm_post_init': llm_post_init, 'LLM': LLM}
|
||||
# _cached_LLMFunction_get and _cached_LLMSerialisation_get
|
||||
globs.update({f'_cached_{cl_.__name__}_get': _object_getattribute.__get__(cl_) for cl_ in {LLMSerialisation, LLMFunction}})
|
||||
# llm_post_init implementation
|
||||
lines: ListStr = [f'_impl_{cls.__name__}_func=cls.llm_post_init', _setattr_class('llm_post_init', f'__wrapped_llm_post_init(_impl_{cls.__name__}_func)')]
|
||||
|
||||
serialisation_attr = {'import_model': import_model, 'load_model': load_model, 'load_tokenizer': load_tokenizer,}
|
||||
for func, impl in serialisation_attr.items():
|
||||
impl_name = f'__wrapped_{func}'
|
||||
globs.update({f'__serialisation_{func}': getattr(openllm.serialisation, func, None), impl_name: impl})
|
||||
cached_func_name = f'_cached_{cls.__name__}_func'
|
||||
func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMSerialisation_get('{func}') else __serialisation_{func}"
|
||||
lines.extend([f'{cached_func_name}=cls.{func}', func_call, _setattr_class(func, f'{impl_name}(_impl_{cls.__name__}_{func})')])
|
||||
|
||||
interface_anns = codegen.get_annotations(LLMInterface)
|
||||
|
||||
# cached attribute initialisation
|
||||
def dunder_cached(key: str) -> str:
|
||||
return f'__llm_{key}__'
|
||||
|
||||
st_attr = {'model', 'tokenizer', 'adapter_map'}
|
||||
lines.extend([_setattr_class(dunder_cached(v), None) for v in st_attr])
|
||||
|
||||
# boolean for better LLM implementation resolver
|
||||
def dunder_support(key: str) -> str:
|
||||
return f'__llm_supports_{key}__'
|
||||
|
||||
bool_attr = {it[15:-2] for it in interface_anns if it.startswith('__llm_supports_')}
|
||||
lines.extend([_setattr_class(dunder_support(fn), f"cls.{fn} is not _cached_LLMFunction_get('{fn}')") for fn in bool_attr])
|
||||
|
||||
return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations={'cls': 't.Type[LLM]', 'return': None})
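`codegen.generate_function` builds the `__assign_llm_attr` initialiser by exec-ing the generated source lines against the `globs` namespace assembled above. The helper below is only a rough illustration of that exec-based codegen idea, not the actual `openllm_core.utils.codegen` implementation:

```python
def generate_function(name, lines, args, globs):
  # Assemble source for a function whose body is the given statement lines.
  src = f"def {name}({', '.join(args)}):\n" + ''.join(f'  {line}\n' for line in lines)
  ns = {}
  exec(compile(src, f'<generated {name}>', 'exec'), globs, ns)
  return ns[name]

set_flag = generate_function('set_flag', ['cls.flag = True'], ('cls',), {})

class Dummy: ...

set_flag(Dummy)
print(Dummy.flag)  # True
```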
|
||||
@@ -1,309 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
from enum import IntEnum
|
||||
from enum import auto
|
||||
|
||||
import attr
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm_core
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
class SeparatorStyle(IntEnum):
|
||||
'''Separator styles.'''
|
||||
|
||||
# Generic separator styles for chat models
|
||||
ADD_COLON_SINGLE = auto()
|
||||
ADD_COLON_TWO = auto()
|
||||
ADD_COLON_SPACE_SINGLE = auto()
|
||||
NO_COLON_SINGLE = auto()
|
||||
NO_COLON_TWO = auto()
|
||||
ADD_NEW_LINE_SINGLE = auto()
|
||||
|
||||
# Special separator styles for specific chat models in OpenLLM
|
||||
LLAMA = auto()
|
||||
CHATGLM = auto()
|
||||
DOLLY = auto()
|
||||
MPT = auto()
|
||||
STARCODER = auto()
|
||||
|
||||
@attr.define
|
||||
class Conversation:
|
||||
'''A class that manages prompt templates and keeps all conversation history.'''
|
||||
|
||||
# The name of this template
|
||||
name: str
|
||||
# The template of the system prompt
|
||||
system_template: str = '{system_message}'
|
||||
# The system message
|
||||
system_message: str = ''
|
||||
# The names of two roles
|
||||
roles: t.Tuple[str, str] = ('User', 'Assistant')
|
||||
# All messages. Each item is (role, message).
|
||||
messages: t.List[t.List[str]] = attr.Factory(list)
|
||||
# The number of few shot examples
|
||||
offset: int = 0
|
||||
# The separator style and configurations
|
||||
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
|
||||
sep: str = '\n'
|
||||
sep2: str = ''
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: t.Union[str, t.List[str]] = ''
|
||||
# Stops generation if meeting any token in this list
|
||||
stop_token_ids: t.List[int] = []
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
'''Get the prompt for generation.'''
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
|
||||
# Generic separator styles for chat models
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: # Role with colon
|
||||
ret = system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ': ' + message + self.sep
|
||||
else:
|
||||
ret += role + ':'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: # Role with colon, two different separators for two roles
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ': ' + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ':'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: # Add a space after colon
|
||||
ret = system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ': ' + message + self.sep
|
||||
else:
|
||||
ret += role + ': ' # must end with a space
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: # Add a new line after role
|
||||
ret = '' if system_prompt == '' else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + '\n' + message + self.sep
|
||||
else:
|
||||
ret += role + '\n'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: # No colon
|
||||
ret = system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + message + self.sep
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.NO_COLON_TWO: # No colon, two different separators for two roles
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + message + seps[i % 2]
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
# Special separator styles for specific chat models
|
||||
elif self.sep_style == SeparatorStyle.LLAMA:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
ret = system_prompt
|
||||
else:
|
||||
ret = '<s>[INST] '
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
tag = self.roles[i % 2]
|
||||
if message:
|
||||
if i == 0:
|
||||
ret += message + ' '
|
||||
else:
|
||||
ret += tag + ' ' + message + seps[i % 2]
|
||||
else:
|
||||
ret += tag
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.CHATGLM:
|
||||
round_add_n = 1 if self.name == 'chatglm2' else 0
|
||||
if system_prompt:
|
||||
ret = system_prompt + self.sep
|
||||
else:
|
||||
ret = ''
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
ret += f'[Round {i//2 + round_add_n}]{self.sep}'
|
||||
if message:
|
||||
ret += f'{role}:{message}{self.sep}'
|
||||
else:
|
||||
ret += f'{role}:'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ':\n' + message + seps[i % 2]
|
||||
if i % 2 == 1:
|
||||
ret += '\n\n'
|
||||
else:
|
||||
ret += role + ':\n'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.MPT:
|
||||
if system_prompt:
|
||||
ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}'
|
||||
else:
|
||||
ret = ''
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}'
|
||||
else:
|
||||
ret += f'{role}:'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.STARCODER:
|
||||
if system_prompt:
|
||||
ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}'
|
||||
else:
|
||||
ret = ''
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += f'{role}\n{message}<|end|>{self.sep}'
|
||||
else:
|
||||
ret += f'{role}:'
|
||||
else:
|
||||
raise ValueError(f'Invalid style: {self.sep_style}')
|
||||
return ret
|
||||
|
||||
def set_system_message(self, system_message: str) -> None:
|
||||
_object_setattr(self, 'system_message', system_message)
|
||||
|
||||
def append_message(self, role: str, message: str) -> None:
|
||||
'''Append a new message.'''
|
||||
self.messages.append([role, message])
|
||||
|
||||
def update_last_message(self, message: str) -> None:
|
||||
'''Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
'''
|
||||
self.messages[-1][1] = message
|
||||
|
||||
def to_openai_api_messages(self) -> t.List[t.Dict[str, str]]:
|
||||
'''Convert the conversation to OpenAI chat completion format.'''
|
||||
ret = [{'role': 'system', 'content': self.system_message}]
|
||||
|
||||
for i, (_, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
ret.append({'role': 'user', 'content': msg})
|
||||
elif msg is not None:
|
||||
ret.append({'role': 'assistant', 'content': msg})
|
||||
return ret
|
||||
|
||||
def copy(self) -> Conversation:
|
||||
return Conversation(name=self.name,
|
||||
system_template=self.system_template,
|
||||
system_message=self.system_message,
|
||||
roles=self.roles,
|
||||
messages=self.messages,
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
stop_token_ids=self.stop_token_ids)
|
||||
|
||||
# A global registry for all conversation templates for OpenLLM models
|
||||
conv_templates: t.Dict[str, Conversation] = {}
|
||||
|
||||
def register_conv_template(template: Conversation) -> None:
|
||||
'''Register a new conversation template.'''
|
||||
conv_templates[template.name] = template
|
||||
|
||||
def get_conv_template(name: str, llm_config: openllm_core.LLMConfig) -> Conversation:
|
||||
if name not in conv_templates: raise ValueError(f'Failed to find conversation templates for {name}')
|
||||
template = conv_templates[name].copy()
|
||||
if hasattr(llm_config, 'default_system_message'): template.set_system_message(llm_config.default_system_message)
|
||||
return template
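A hedged usage sketch of the (now-removed) `Conversation` API defined above, showing how a prompt is built up and then converted to OpenAI-style messages; the role names and messages are made up for illustration:

```python
conv = Conversation(name='demo', roles=('User', 'Assistant'),
                    sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n')
conv.append_message('User', 'What is OpenLLM?')
conv.append_message('Assistant', None)  # placeholder until the model replies
print(conv.get_prompt())                # '\nUser: What is OpenLLM?\nAssistant:'
conv.update_last_message('An open platform for operating LLMs in production.')
print(conv.to_openai_api_messages())    # system + user + assistant dicts
```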
|
||||
|
||||
# Raw template
|
||||
register_conv_template(Conversation(name='raw', system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
|
||||
|
||||
# Llama template
|
||||
# source: https://huggingface.co/blog/codellama#conversational-instructions
|
||||
register_conv_template(
|
||||
Conversation(name='llama', system_template='<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' </s><s>',
|
||||
))
|
||||
|
||||
# ChatGLM template
|
||||
register_conv_template(Conversation(name='chatglm', roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n',))
|
||||
|
||||
# Dolly-v2 template
|
||||
register_conv_template(
|
||||
Conversation(name='dolly_v2',
|
||||
system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
|
||||
roles=('### Instruction', '### Response'),
|
||||
sep_style=SeparatorStyle.DOLLY,
|
||||
sep='\n\n',
|
||||
sep2='### End',
|
||||
))
|
||||
|
||||
# Falcon template
|
||||
register_conv_template(
|
||||
# source: https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
|
||||
Conversation(name='falcon', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, # No space after colon
|
||||
sep='\n',
|
||||
))
|
||||
|
||||
# Flan-T5 default template
|
||||
register_conv_template(
|
||||
# source: https://www.philschmid.de/fine-tune-flan-t5
|
||||
# No specific template found, but seems to have the same dialogue style
|
||||
Conversation(name='flan-t5', system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
|
||||
|
||||
# GPT-NeoX default template
|
||||
register_conv_template(
|
||||
# source: https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
|
||||
# It is unclear whether GPT-NeoX-20B was trained on any chat prompt template
|
||||
Conversation(name='gpt-neox', system_message='', roles=('<human>', '<bot>'), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'))
|
||||
|
||||
# MPT template
|
||||
register_conv_template(
|
||||
# source: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
|
||||
Conversation(name='mpt', roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'))
|
||||
|
||||
# OPT template (No reference for OPT found)
|
||||
register_conv_template(Conversation(name='opt', roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'))
|
||||
|
||||
# StableLM default template
|
||||
register_conv_template(
|
||||
Conversation(name='stablelm',
|
||||
system_template='<|SYSTEM|>{system_message}',
|
||||
system_message='''# StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
''',
|
||||
roles=('<|USER|>', '<|ASSISTANT|>'),
|
||||
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
||||
sep='',
|
||||
stop_token_ids=[50278, 50279, 50277, 1, 0],
|
||||
))
|
||||
|
||||
# StarCoder default template
|
||||
register_conv_template(
|
||||
# source: https://github.com/bigcode-project/starcoder/blob/main/chat/dialogues.py
|
||||
Conversation(name='starcoder', system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'))
|
||||
|
||||
# Baichuan default template
|
||||
register_conv_template(
|
||||
# source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
|
||||
# https://github.com/baichuan-inc/Baichuan-13B/issues/25
|
||||
Conversation(name='baichuan', roles=('<reserved_102>', '<reserved_103>'), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''))
|
||||
|
||||
# Mistral template
|
||||
register_conv_template(Conversation(name='mistral', system_message='', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2='</s>',))
|
||||
openllm-python/src/openllm/_deprecated.py (new file, 92 lines)
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
import openllm
|
||||
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
from ._llm import LLMRunner
|
||||
P = ParamSpec('P')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def _mark_deprecated(fn: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
_object_setattr(fn, '__deprecated__', True)
|
||||
return fn
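`_mark_deprecated` only tags the callable; the tag is later collected into `_DEPRECATED` and surfaced through `__dir__` below. A quick illustration (the `old_api` function is hypothetical):

```python
@_mark_deprecated
def old_api() -> None:
  ...

print(getattr(old_api, '__deprecated__', False))  # True
```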
|
||||
|
||||
@_mark_deprecated
|
||||
def Runner(model_name: str,
|
||||
ensure_available: bool = False,
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
**attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
'''Create a Runner for a given LLM. For a list of currently supported LLMs, check out 'openllm models'.
|
||||
|
||||
> [!WARNING]
|
||||
> This method is now deprecated in favor of 'openllm.LLM.runner'
|
||||
|
||||
```python
|
||||
runner = openllm.Runner("dolly-v2")
|
||||
|
||||
@svc.on_startup
|
||||
def download():
|
||||
runner.download_model()
|
||||
```
|
||||
|
||||
If `init_local=True` (for development workflows), `ensure_available` is also enabled.
|
||||
The default value of `ensure_available` is None. If a value is given it is used as-is; otherwise the behaviour above applies.
|
||||
|
||||
Args:
|
||||
model_name: Supported model name from 'openllm models'
|
||||
ensure_available: If True, it will download the model if it is not available. If False, it will skip downloading the model.
|
||||
If False, make sure the model is available locally.
|
||||
backend: The backend implementation to use for this Runner. If `OPENLLM_BACKEND` is set, it takes precedence.
|
||||
llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``.
|
||||
init_local: If True, initialise the model locally; useful when you want to run the model in-process. (Symmetrical to bentoml.Runner.init_local())
|
||||
**attrs: The remaining kwargs are passed to the LLM. Refer to the LLM documentation for their behaviour.
|
||||
'''
|
||||
from ._llm import LLM
|
||||
if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
model_id = attrs.get('model_id') or llm_config['env']['model_id_value']
|
||||
_RUNNER_MSG = f'''\
|
||||
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
|
||||
|
||||
```python
|
||||
llm = openllm.LLM('{model_id}')
|
||||
|
||||
svc = bentoml.Service('...', runners=[llm.runner])
|
||||
|
||||
@svc.api(...)
|
||||
async def chat(input: str) -> str:
|
||||
async for it in llm.generate_iterator(input): print(it)
|
||||
```
|
||||
'''
|
||||
warnings.warn(_RUNNER_MSG, DeprecationWarning, stacklevel=2)
|
||||
attrs.update({
|
||||
'model_id': model_id,
|
||||
'quantize': llm_config['env']['quantize_value'],
|
||||
'serialisation': first_not_none(attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
})
|
||||
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'))
|
||||
if init_local: ensure_available = True
|
||||
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, **attrs)
|
||||
if ensure_available: llm.save_pretrained()
|
||||
if init_local: llm.runner.init_local(quiet=True)
|
||||
return llm.runner
|
||||
|
||||
_DEPRECATED = {k: v for k, v in locals().items() if getattr(v, '__deprecated__', False)}
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(_DEPRECATED.keys())
|
||||
File diff suppressed because it is too large
@@ -8,6 +8,8 @@ import transformers
|
||||
|
||||
from openllm_core._typing_compat import LiteralQuantise
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import is_autoawq_available
|
||||
from openllm_core.utils import is_autogptq_available
|
||||
from openllm_core.utils import is_bitsandbytes_available
|
||||
from openllm_core.utils import is_optimum_supports_gptq
|
||||
@@ -20,25 +22,36 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[transformers.GPTQConfig, DictStrAny]:
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[transformers.GPTQConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: LiteralQuantise, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig, DictStrAny]:
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any) -> tuple[transformers.AwqConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise,
|
||||
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop('llm_int8_skip_modules', None)
|
||||
int8_has_fp16_weight = attrs.pop('llm_int8_has_fp16_weight', False)
|
||||
|
||||
# shared arguments for gptq and awq
|
||||
bits = attrs.pop('bits', 4)
|
||||
group_size = attrs.pop('group_size', 128)
|
||||
|
||||
def create_awq_config() -> transformers.AwqConfig:
|
||||
zero_point = attrs.pop('zero_point', True)
|
||||
return transformers.AwqConfig(bits=bits, group_size=group_size, zero_point=zero_point)
|
||||
|
||||
def create_gptq_config() -> transformers.GPTQConfig:
|
||||
gptq_bits = attrs.pop('bits', 4)
|
||||
gptq_tokenizer = attrs.pop('tokenizer', None)
|
||||
gptq_tokenizer = attrs.pop('tokenizer', self.model_id)
|
||||
gptq_dataset = attrs.pop('dataset', 'c4')
|
||||
gptq_group_size = attrs.pop('group_size', 128)
|
||||
gptq_damp_percent = attrs.pop('damp_percent', 0.1)
|
||||
gptq_desc_act = attrs.pop('desc_act', False)
|
||||
gptq_sym = attrs.pop('sym', True)
|
||||
@@ -50,10 +63,10 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: LiteralQua
|
||||
gptq_batch_size = attrs.pop('batch_size', 1)
|
||||
gptq_pad_token_id = attrs.pop('pad_token_id', None)
|
||||
gptq_disable_exllama = attrs.pop('disable_exllama', False)
|
||||
return transformers.GPTQConfig(bits=gptq_bits,
|
||||
return transformers.GPTQConfig(bits=bits,
|
||||
tokenizer=gptq_tokenizer,
|
||||
dataset=gptq_dataset,
|
||||
group_size=gptq_group_size,
|
||||
group_size=group_size,
|
||||
damp_percent=gptq_damp_percent,
|
||||
desc_act=gptq_desc_act,
|
||||
sym=gptq_sym,
|
||||
@@ -67,25 +80,22 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: LiteralQua
|
||||
disable_exllama=gptq_disable_exllama)
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
if int8_skip_modules is None: int8_skip_modules = []
|
||||
if 'lm_head' not in int8_skip_modules and cls.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append('lm_head')
|
||||
# if int8_skip_modules is None: int8_skip_modules = []
|
||||
# if 'lm_head' not in int8_skip_modules and self.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
# logger.debug("Skipping 'lm_head' for quantization for %s", self.__name__)
|
||||
# int8_skip_modules.append('lm_head')
|
||||
return transformers.BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
|
||||
int4_quant_type = attrs.pop('bnb_4bit_quant_type', 'nf4')
|
||||
int4_use_double_quant = attrs.pop('bnb_4bit_use_double_quant', True)
|
||||
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
# NOTE: Quantization setup. 'quantize' is an openllm.LLM feature that quantises the model with bitsandbytes or quantisation-aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
|
||||
@@ -96,12 +106,15 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: LiteralQua
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant)
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available() or not is_optimum_supports_gptq():
|
||||
logger.warning(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (not available with local environment). Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback to int8 with bitsandbytes."
|
||||
)
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
raise MissingDependencyError(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'")
|
||||
else:
|
||||
quantisation_config = create_gptq_config()
|
||||
elif quantise == 'awq':
|
||||
if not is_autoawq_available():
|
||||
raise MissingDependencyError("quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'.")
|
||||
else:
|
||||
quantisation_config = create_awq_config()
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq', 'awq'], got {quantise} instead.")
|
||||
return quantisation_config, attrs
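A hedged sketch of how the `(quantisation_config, remaining_attrs)` pair returned above is typically consumed when loading a model with transformers; the model id is a small stand-in and the direct `from_pretrained` call is illustrative (OpenLLM performs this internally):

```python
import transformers

# e.g. what create_int8_config() produces for quantise='int8'
quant_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    'facebook/opt-125m',               # illustrative model id
    quantization_config=quant_config,  # requires bitsandbytes and a CUDA GPU
    device_map='auto')
```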
|
||||
|
||||
openllm-python/src/openllm/_runners.py (new file, 192 lines)
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
import gc
|
||||
import os
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from openllm.exceptions import OpenLLMException
|
||||
from openllm_core._schemas import CompletionChunk
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
from openllm_core.utils import device_count
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
|
||||
from openllm_core._schemas import FinishReason
|
||||
else:
|
||||
vllm = openllm.utils.LazyLoader('vllm', globals(), 'vllm')
|
||||
|
||||
_DEFAULT_TOKENIZER = 'hf-internal-testing/llama-tokenizer'
|
||||
|
||||
__all__ = ['runnable']
|
||||
|
||||
def runnable(backend: LiteralBackend | None = None) -> type[bentoml.Runnable]:
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'))
|
||||
return vLLMRunnable if backend == 'vllm' else PyTorchRunnable
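A hedged sketch of how the Runnable class returned by `runnable()` could be wired into a BentoML runner by hand; in practice `llm.runner` does this for you, and the model id and service name here are purely illustrative:

```python
import bentoml
import openllm

llm = openllm.LLM('facebook/opt-125m')            # illustrative model id
runner = bentoml.Runner(runnable(backend='pt'),   # PyTorchRunnable in this case
                        name='llm-opt-runner',
                        runnable_init_params={'llm': llm})
svc = bentoml.Service('opt-service', runners=[runner])
```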
|
||||
|
||||
class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm: openllm.LLM[M, T]) -> None:
|
||||
self.config = llm.config
|
||||
num_gpus, dev = 1, device_count()
|
||||
if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
|
||||
quantization = None
|
||||
if llm._quantise and llm._quantise == 'awq': quantization = llm._quantise
|
||||
try:
|
||||
self.model = vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
quantization=quantization,
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False))
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
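The `num_gpus` rule in `__init__` above rounds the visible device count down to an even tensor-parallel size once two or more GPUs are available. A tiny standalone check of what it evaluates to:

```python
for dev in range(1, 9):
  num_gpus = 1
  if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
  print(dev, num_gpus)  # 1->1, 2->2, 3->2, 4->4, 5->4, 6->6, 7->6, 8->8
```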
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: raise NotImplementedError('Adapter is not supported with vLLM.')
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
|
||||
temperature = attrs.pop('temperature', self.config['temperature'])
|
||||
top_p = attrs.pop('top_p', self.config['top_p'])
|
||||
if temperature <= 1e-5: top_p = 1.0
|
||||
sampling_params = self.config.model_construct_env(stop=list(stop_), temperature=temperature, top_p=top_p, **attrs).to_sampling_config()
|
||||
|
||||
async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
|
||||
# XXX: Need to write a hook for serialising None correctly
|
||||
if request_output.prompt_logprobs is not None: request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
|
||||
yield f'data: {GenerationOutput.from_vllm(request_output).model_dump_json()}\n\n'
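Both runnables stream server-sent-event style chunks of the form `data: {...}\n\n`. A minimal client-side sketch of decoding one such chunk (the payload fields shown are illustrative):

```python
import json

def parse_sse_chunk(chunk: str) -> dict:
  assert chunk.startswith('data: ')
  return json.loads(chunk[len('data: '):].strip())

example = 'data: {"request_id": "abc", "finished": false, "outputs": []}\n\n'
print(parse_sse_chunk(example)['finished'])  # False
```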
|
||||
|
||||
class PyTorchRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm: openllm.LLM[M, T]) -> None:
|
||||
self.model = llm.model
|
||||
self.tokenizer = llm.tokenizer
|
||||
self.config = llm.config
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: self.model.set_adapter(adapter_name)
|
||||
async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
|
||||
yield f'data: {generation_output.model_dump_json()}\n\n'
|
||||
|
||||
async def forward(self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
from ._generation import is_partial_stop
|
||||
from ._generation import prepare_logits_processor
|
||||
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
with torch.inference_mode():
|
||||
# TODO: Support context_length check
|
||||
# context_length: int | None = attrs.pop('context_length', None)
|
||||
# if context_length is None: context_length = get_context_length(self.model.config)
|
||||
# max_src_len = context_length - config['max_new_tokens'] - 1
|
||||
# prompt_token_ids = prompt_token_ids[-max_src_len:]
|
||||
output_token_ids = list(prompt_token_ids)
|
||||
input_len = len(prompt_token_ids)
|
||||
|
||||
logits_processor = prepare_logits_processor(config)
|
||||
|
||||
past_key_values = out = token = None
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
for i in range(config['max_new_tokens']):
|
||||
if i == 0: # prefill
|
||||
out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
|
||||
else: # decoding
|
||||
out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
if logits_processor:
|
||||
if config['repetition_penalty'] > 1.0:
|
||||
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
|
||||
else:
|
||||
tmp_output_ids = None
|
||||
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
|
||||
else:
|
||||
last_token_logits = logits[0, -1, :]
|
||||
|
||||
# Switch to CPU to avoid some bugs in the mps backend.
|
||||
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')
|
||||
|
||||
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
|
||||
_, indices = torch.topk(last_token_logits, 2)
|
||||
tokens = [int(index) for index in indices.tolist()]
|
||||
else:
|
||||
probs = torch.softmax(last_token_logits, dim=-1)
|
||||
indices = torch.multinomial(probs, num_samples=2)
|
||||
tokens = [int(token) for token in indices.tolist()]
|
||||
|
||||
token = tokens[0]
|
||||
output_token_ids.append(token)
|
||||
|
||||
stopped = False
|
||||
|
||||
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
|
||||
# XXX: Move this to API server
|
||||
text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
|
||||
partially_stopped = False
|
||||
if stop_:
|
||||
for it in stop_:
|
||||
pos = text.rfind(it, rfind_start)
|
||||
if pos != -1:
|
||||
text, stopped = text[:pos], True
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(text, it)
|
||||
if partially_stopped: break
|
||||
if not partially_stopped:
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=False,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
if stopped: break
|
||||
else: finish_reason = 'length'
|
||||
if stopped: finish_reason = 'stop'
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=True,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=finish_reason)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
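The token-selection branch inside the loop above switches between greedy argmax and multinomial sampling depending on temperature/top-p. Extracted as a standalone sketch (assuming temperature and top-p warping were already applied by the logits processors):

```python
import torch

def select_token(last_token_logits: torch.Tensor, temperature: float, top_p: float) -> int:
  if temperature < 1e-5 or top_p < 1e-8:  # effectively greedy decoding
    return int(torch.topk(last_token_logits, 2).indices[0])
  probs = torch.softmax(last_token_logits, dim=-1)
  return int(torch.multinomial(probs, num_samples=1)[0])

print(select_token(torch.tensor([0.1, 2.0, 0.3]), temperature=0.0, top_p=1.0))  # 1
```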
|
||||
@@ -1,23 +1,15 @@
|
||||
# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
import warnings
|
||||
|
||||
import _service_vars as svars
|
||||
import orjson
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import openllm_core
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
# The following warnings come from bitsandbytes and are probably not important for users to see
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
@@ -29,193 +21,37 @@ logger = logging.getLogger(__name__)
|
||||
model = svars.model
|
||||
model_id = svars.model_id
|
||||
adapter_map = svars.adapter_map
|
||||
model_tag = svars.model_tag
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, model_id=model_id, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id,
|
||||
llm_config=llm_config,
|
||||
model_tag=model_tag,
|
||||
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), getattr(llm_config, 'default_prompt_template', None)),
|
||||
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), getattr(llm_config, 'default_system_message', None)),
|
||||
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']),
|
||||
adapter_map=orjson.loads(adapter_map))
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
|
||||
|
||||
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
|
||||
llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
|
||||
|
||||
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerateOutput:
|
||||
echo = input_dict.pop('echo', False)
|
||||
qa_inputs = openllm.GenerateInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
if runner.backend == 'vllm':
|
||||
async for output in runner.vllm_generate.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, request_id=openllm_core.utils.gen_random_uuid(), **config):
|
||||
responses = output
|
||||
if responses is None: raise ValueError("'responses' should not be None.")
|
||||
else:
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
|
||||
return openllm.GenerateOutput(responses=responses, configuration=config)
|
||||
@svc.api(route='/v1/generate', input=bentoml.io.JSON.from_sample(llm_model_class.examples().model_dump()), output=bentoml.io.JSON.from_sample(openllm.GenerationOutput.examples().model_dump()))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return await llm.generate(**llm_model_class(**input_dict).model_dump())
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
|
||||
@svc.api(route='/v1/generate_stream', input=bentoml.io.JSON.from_sample(llm_model_class.examples().model_dump()), output=bentoml.io.Text(content_type='text/event-stream'))
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
echo = input_dict.pop('echo', False)
|
||||
qa_inputs = openllm.GenerateInput.from_llm_config(llm_config)(**input_dict)
|
||||
if runner.backend == 'vllm':
|
||||
return runner.vllm_generate_iterator.async_stream(qa_inputs.prompt,
|
||||
adapter_name=qa_inputs.adapter_name,
|
||||
echo=echo,
|
||||
request_id=openllm_core.utils.gen_random_uuid(),
|
||||
**qa_inputs.llm_config.model_dump())
|
||||
else:
|
||||
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
|
||||
yield f'data: {it.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
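The streaming route above speaks Server-Sent Events: every chunk arrives as a `data: <json>` frame followed by a blank line, and the stream is terminated with `data: [DONE]`. A minimal client-side sketch, assuming an `httpx`-based consumer, a server on http://localhost:3000, and the request/response shapes sampled above (key names may differ in your version):

import httpx, orjson

def stream_generate(prompt: str, **llm_config):
  with httpx.stream('POST', 'http://localhost:3000/v1/generate_stream', json={'prompt': prompt, 'llm_config': llm_config}, timeout=None) as resp:
    for line in resp.iter_lines():
      if not line.startswith('data: '): continue
      payload = line[len('data: '):]
      if payload == '[DONE]': break
      yield orjson.loads(payload)  # one GenerationOutput dump per SSE frame

for chunk in stream_generate('What is the meaning of life?', max_new_tokens=128):
  print(chunk['outputs'][0]['text'], end='', flush=True)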
|
||||
|
||||
@svc.api(route='v1/completions',
|
||||
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
|
||||
output=bentoml.io.Text())
|
||||
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = input_dict.pop('prompt', None)
|
||||
if prompt is None: raise ValueError("'prompt' should not be None.")
|
||||
stream = input_dict.pop('stream', False)
|
||||
config = {
|
||||
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
|
||||
'temperature': input_dict.pop('temperature', llm_config['temperature']),
|
||||
'top_p': input_dict.pop('top_p', llm_config['top_p']),
|
||||
'n': input_dict.pop('n', llm_config['n']),
|
||||
'logprobs': input_dict.pop('logprobs', llm_config['logprobs']),
|
||||
'echo': input_dict.pop('echo', False),
|
||||
'stop': input_dict.pop('stop', llm_config['stop']),
|
||||
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
|
||||
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
|
||||
'best_of': input_dict.pop('best_of', llm_config['best_of']),
|
||||
}
|
||||
|
||||
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
|
||||
async for response in responses:
|
||||
st = openllm.openai.CompletionResponseStream(choices=[openllm.openai.CompletionTextChoice(text=response, index=0)], model=runner.llm_type) # TODO: logprobs, finish_reason
|
||||
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
if stream:
|
||||
ctx.response.headers['Content-Type'] = 'text/event-stream'
|
||||
if runner.backend == 'vllm':
|
||||
responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
|
||||
else:
|
||||
responses = runner.generate_iterator.async_stream(prompt, **config)
|
||||
return stream_response_generator(responses)
|
||||
else:
|
||||
ctx.response.headers['Content-Type'] = 'application/json'
|
||||
if runner.backend == 'vllm':
|
||||
async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
|
||||
responses = output
|
||||
if responses is None: raise ValueError("'responses' should not be None.")
|
||||
else:
|
||||
responses = await runner.generate.async_run(prompt, **config)
|
||||
|
||||
return orjson.dumps(
|
||||
openllm.utils.bentoml_cattr.unstructure(
|
||||
openllm.openai.CompletionResponse(choices=[openllm.openai.CompletionTextChoice(text=response, index=i) for i, response in enumerate(responses)],
|
||||
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
|
||||
)).decode()
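Because this route mirrors OpenAI's completions API, an off-the-shelf OpenAI client can be pointed at it. A hedged sketch using the pre-1.0 `openai` package; the base URL and the placeholder model name are assumptions, and the model must match the `llm_type` reported by `GET /v1/models`:

import openai

openai.api_base = 'http://localhost:3000/v1'
openai.api_key = 'na'  # the endpoint does not validate credentials
completion = openai.Completion.create(model='<llm_type from /v1/models>', prompt='What is 1+1?', max_tokens=64, temperature=0.7)
print(completion.choices[0].text)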
|
||||
|
||||
@svc.api(route='/v1/chat/completions',
|
||||
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}], model=runner.llm_type))),
|
||||
output=bentoml.io.Text())
|
||||
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
|
||||
_model = input_dict.get('model', None)
|
||||
if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
|
||||
prompt = openllm.openai.messages_to_prompt(input_dict['messages'], model, llm_config)
|
||||
stream = input_dict.pop('stream', False)
|
||||
config = {
|
||||
'temperature': input_dict.pop('temperature', llm_config['temperature']),
|
||||
'top_p': input_dict.pop('top_p', llm_config['top_p']),
|
||||
'n': input_dict.pop('n', llm_config['n']),
|
||||
'echo': input_dict.pop('echo', False),
|
||||
'stop': input_dict.pop('stop', llm_config['stop']),
|
||||
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
|
||||
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
|
||||
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
|
||||
'_format_chat_template': True,
|
||||
}
|
||||
|
||||
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
|
||||
async for response in responses:
|
||||
st = openllm.openai.ChatCompletionResponseStream(
|
||||
choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=response), finish_reason=None)], model=runner.llm_type)
|
||||
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
|
||||
final = openllm.openai.ChatCompletionResponseStream(
|
||||
choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=''), finish_reason='stop')], model=runner.llm_type)
|
||||
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(final)).decode()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
if stream:
|
||||
ctx.response.headers['Content-Type'] = 'text/event-stream'
|
||||
if runner.backend == 'vllm':
|
||||
responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
|
||||
else:
|
||||
responses = runner.generate_iterator.async_stream(prompt, **config)
|
||||
return stream_response_generator(responses)
|
||||
else:
|
||||
ctx.response.headers['Content-Type'] = 'application/json'
|
||||
if runner.backend == 'vllm':
|
||||
async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
|
||||
responses = output
|
||||
if responses is None: raise ValueError("'responses' should not be None.")
|
||||
else:
|
||||
responses = await runner.generate.async_run(prompt, **config)
|
||||
return orjson.dumps(
|
||||
openllm.utils.bentoml_cattr.unstructure(
|
||||
openllm.openai.ChatCompletionResponse(
|
||||
choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
|
||||
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
|
||||
)).decode('utf-8')
|
||||
|
||||
def models_v1(_: Request) -> Response:
|
||||
return JSONResponse(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[openllm.openai.ModelCard(id=runner.llm_type)])), status_code=200)
|
||||
|
||||
openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
|
||||
svc.mount_asgi_app(openai_app, path='/v1')
|
||||
|
||||
@svc.api(route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'model_id': runner.llm.model_id,
|
||||
'timeout': 3600,
|
||||
'model_name': llm_config['model_name'],
|
||||
'backend': runner.backend,
|
||||
'configuration': llm_config.model_dump(flatten=True),
|
||||
'supports_hf_agent': runner.supports_hf_agent,
|
||||
'prompt_template': runner.prompt_template,
|
||||
'system_message': runner.system_message,
|
||||
}))
|
||||
@svc.api(route='/v1/metadata', input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample(openllm.MetadataOutput.examples(llm).model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
backend=llm_config['env']['backend_value'],
|
||||
model_id=runner.llm.model_id,
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_hf_agent=runner.supports_hf_agent,
|
||||
prompt_template=runner.prompt_template,
|
||||
system_message=runner.system_message,
|
||||
)
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message)
|
||||
|
||||
if runner.supports_hf_agent:
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
json_str = await request.body()
|
||||
try:
|
||||
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Invalid JSON input received: {err}') from None
|
||||
stop = input_data.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError:
|
||||
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
|
||||
hf_app = Starlette(debug=True, routes=[Route('/agent', hf_agent, methods=['POST'])])
|
||||
svc.mount_asgi_app(hf_app, path='/hf')
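With the `/hf/agent` route mounted, older `transformers` releases that still ship the `HfAgent` tooling can use this server as an agent backend. A brief sketch (the local endpoint URL is an assumption):

from transformers import HfAgent  # only available in transformers releases that include the agents tooling

agent = HfAgent('http://localhost:3000/hf/agent')
print(agent.run('Is the following `text` positive or negative?', text='The food was terrible.'))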
|
||||
|
||||
# general metadata app
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters['success'] is True:
|
||||
res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}
|
||||
res.update({'success': runner.peft_adapters['success'], 'error_msg': runner.peft_adapters['error_msg']})
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
adapters_app_v1 = Starlette(debug=True, routes=[Route('/adapters', list_adapter_v1, methods=['GET'])])
|
||||
svc.mount_asgi_app(adapters_app_v1, path='/v1')
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some monkey patching of the OpenAPI schema.
|
||||
|
||||
@@ -3,4 +3,5 @@ import os

model = os.environ['OPENLLM_MODEL'] # openllm: model name
model_id = os.environ['OPENLLM_MODEL_ID'] # openllm: model name
model_tag = None # openllm: model tag
adapter_map = os.environ['OPENLLM_ADAPTER_MAP'] # openllm: model adapter map

@@ -2,4 +2,5 @@ from __future__ import annotations

model = '{__model_name__}' # openllm: model name
model_id = '{__model_id__}' # openllm: model id
model_tag = '{__model_tag__}' # openllm: model tag
adapter_map = '''{__model_adapter_map__}''' # openllm: model adapter map

@@ -64,7 +64,7 @@ def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'ope
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or build it from the Git source.')
|
||||
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str] | None = None,) -> PythonOptions:
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
# NOTE: add openllm to the default dependencies
|
||||
@@ -79,32 +79,10 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env = llm.config['env']
|
||||
backend_envvar = env['backend_value']
|
||||
if backend_envvar == 'flax':
|
||||
if not openllm_core.utils.is_flax_available():
|
||||
raise ValueError(f"Flax is not available, while {env.backend} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version('flax'), importlib.metadata.version('jax'), importlib.metadata.version('jaxlib')])
|
||||
elif backend_envvar == 'tf':
|
||||
if not openllm_core.utils.is_tf_available():
|
||||
raise ValueError(f"TensorFlow is not available, while {env.backend} is set to 'tf'")
|
||||
candidates = ('tensorflow', 'tensorflow-cpu', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-cpu', 'tf-nightly-gpu', 'intel-tensorflow', 'intel-tensorflow-avx512', 'tensorflow-rocm',
|
||||
'tensorflow-macos',
|
||||
)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
try:
|
||||
pkgver = importlib.metadata.version(candidate)
|
||||
if pkgver == candidate: packages.extend(['tensorflow'])
|
||||
else:
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f'tensorflow>={_tf_version}'])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
else:
|
||||
if not openllm_core.utils.is_torch_available():
|
||||
raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
env['backend_value']
|
||||
if not openllm_core.utils.is_torch_available():
|
||||
raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
wheels: list[str] = []
|
||||
built_wheels: list[str | None] = [build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p)) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
||||
@@ -115,9 +93,9 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
lock_packages=False,
|
||||
extra_index_url=['https://download.pytorch.org/whl/cu118', 'https://huggingface.github.io/autogptq-index/whl/cu118/'])
|
||||
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: float, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None,
|
||||
dockerfile_template: str | None, serialisation: LiteralSerialisation, container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any,
|
||||
t.Any], _: FS, workers_per_resource: float, quantize: LiteralString | None, adapter_map: dict[str, str] | None, dockerfile_template: str | None,
|
||||
serialisation: LiteralSerialisation, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], workers_per_resource, None, True, os.environ.copy())
|
||||
env: openllm_core.utils.EnvVarMixin = llm.config['env']
|
||||
@@ -141,6 +119,7 @@ def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_
|
||||
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ID = '# openllm: model id'
|
||||
OPENLLM_MODEL_TAG = '# openllm: model tag'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
@@ -164,16 +143,20 @@ class ModelNameFormatter(string.Formatter):
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_id__'
|
||||
|
||||
class ModelTagFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_tag__'
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_adapter_map__'
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
_service_vars_file = Path(os.path.abspath(__file__)).parent.parent / '_service_vars_pkg.py'
|
||||
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
model_name = llm.config['model_name']
|
||||
model_id = llm.model_id
|
||||
model_tag = str(llm.tag)
|
||||
logger.debug('Generating service vars file for %s at %s (dir=%s)', model_name, '_service_vars.py', llm_fs.getsyspath('/'))
|
||||
with open(_service_vars_file.__fspath__(), 'r') as f:
|
||||
src_contents = f.readlines()
|
||||
@@ -182,6 +165,8 @@ def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | N
|
||||
src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
if OPENLLM_MODEL_ID in it:
|
||||
src_contents[src_contents.index(it)] = (ModelIdFormatter(model_id).vformat(it)[:-(len(OPENLLM_MODEL_ID) + 3)] + '\n')
|
||||
elif OPENLLM_MODEL_TAG in it:
|
||||
src_contents[src_contents.index(it)] = (ModelTagFormatter(model_tag).vformat(it)[:-(len(OPENLLM_MODEL_TAG) + 3)] + '\n')
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
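The sentinel comments above (`# openllm: model id`, `# openllm: model tag`, ...) mark which template line each formatter should rewrite, and the trailing slice drops the marker once the value has been substituted. As a rough illustration of the substitution step only (not the actual `ModelNameFormatter` implementation), a `string.Formatter` subclass that resolves just its own keyword could look like this:

import string

class SingleKeywordFormatter(string.Formatter):
  keyword = '__model_name__'
  def __init__(self, value: str) -> None:
    self.value = value
  def get_value(self, key, args, kwargs):
    # Resolve only our keyword; leave any other brace expression untouched.
    return self.value if key == self.keyword else '{' + str(key) + '}'

line = "model = '{__model_name__}' # openllm: model name\n"
print(SingleKeywordFormatter('opt').vformat(line, (), {}))  # -> model = 'opt' # openllm: model name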
|
||||
@@ -200,7 +185,7 @@ def create_bento(bento_tag: bentoml.Tag,
|
||||
workers_per_resource: str | float,
|
||||
quantize: LiteralString | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
|
||||
@@ -11,6 +11,7 @@ import inflection
|
||||
import orjson
|
||||
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import ClickException
|
||||
from click import shell_completion as sc
|
||||
from click.shell_completion import CompletionItem
|
||||
|
||||
@@ -28,6 +29,9 @@ from openllm_core._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import get_literal_args
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import check_bool_env
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
from . import termui
|
||||
|
||||
@@ -62,7 +66,6 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
|
||||
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
@@ -84,7 +87,8 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
|
||||
return None
|
||||
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
@@ -117,24 +121,23 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend,
|
||||
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
if _serialisation == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
|
||||
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
|
||||
if _serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
termui.echo(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
|
||||
fg='yellow')
|
||||
termui.echo(f"Make sure to check out the '{model_id}' repository to see if the weights are in '{_serialisation}' format if unsure.")
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
adapter_map: dict[str, str] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
|
||||
server_timeout = first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
|
||||
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
|
||||
# NOTE: currently, there are no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop('development')
|
||||
server_attrs.setdefault('production', not development)
|
||||
wpr = openllm.utils.first_not_none(workers_per_resource, default=config['workers_per_resource'])
|
||||
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
|
||||
|
||||
if isinstance(wpr, str):
|
||||
if wpr == 'round_robin': wpr = 1.0
|
||||
@@ -151,7 +154,10 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
wpr = float(wpr)
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = openllm.utils.EnvVarMixin(config['model_name'], backend, model_id=model_id or config['default_id'], quantize=quantize)
|
||||
env = openllm.utils.EnvVarMixin(config['model_name'],
|
||||
backend=openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
|
||||
model_id=model_id or config['default_id'],
|
||||
quantize=quantize)
|
||||
requirements = llm_config['requirements']
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
@@ -176,16 +182,16 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
|
||||
model_id=start_env[env.model_id],
|
||||
model_version=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
ensure_available=True,
|
||||
adapter_map=adapter_map,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id=start_env[env.model_id],
|
||||
revision=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=config,
|
||||
backend=env['backend_value'],
|
||||
adapter_map=adapter_map,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
llm.save_pretrained() # ensure_available = True
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode()})
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
|
||||
@@ -382,8 +388,8 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
return cli_option('--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:-2]),
|
||||
default='pt',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
@@ -396,7 +402,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
return cli_option('--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(['int8', 'int4', 'gptq']),
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
|
||||
@@ -16,6 +16,7 @@ import openllm_core
|
||||
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm.exceptions import OpenLLMException
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
from . import termui
|
||||
from ._factory import start_command_factory
|
||||
@@ -88,9 +89,7 @@ def _start(model_name: str,
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
from .entrypoint import start_grpc_command
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
_ModelEnv = openllm_core.utils.EnvVarMixin(model_name, backend=openllm_core.utils.first_not_none(backend, default=llm_config.default_backend()), model_id=model_id, quantize=quantize)
|
||||
os.environ[_ModelEnv.backend] = _ModelEnv['backend_value']
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
args: list[str] = []
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
@@ -218,7 +217,7 @@ def _import_model(model_name: str,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend = 'pt',
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: t.Literal['legacy', 'safetensors'] | None = None,
|
||||
additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
|
||||
@@ -254,7 +253,8 @@ def _import_model(model_name: str,
|
||||
from .entrypoint import import_command
|
||||
config = openllm.AutoConfig.for_model(model_name)
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=config['serialisation'])
|
||||
args = [model_name, '--backend', backend, '--machine', '--serialisation', _serialisation]
|
||||
args = [model_name, '--machine', '--serialisation', _serialisation]
|
||||
if backend is not None: args.extend(['--backend', backend])
|
||||
if model_id is not None: args.append(model_id)
|
||||
if model_version is not None: args.extend(['--model-version', str(model_version)])
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
|
||||
@@ -27,9 +27,7 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
@@ -48,20 +46,11 @@ from simple_di import inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import openllm_core
|
||||
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
from openllm import bundle
|
||||
from openllm.exceptions import OpenLLMException
|
||||
from openllm.models.auto import CONFIG_MAPPING
|
||||
from openllm.models.auto import MODEL_FLAX_MAPPING_NAMES
|
||||
from openllm.models.auto import MODEL_MAPPING_NAMES
|
||||
from openllm.models.auto import MODEL_TF_MAPPING_NAMES
|
||||
from openllm.models.auto import MODEL_VLLM_MAPPING_NAMES
|
||||
from openllm.models.auto import AutoConfig
|
||||
from openllm.models.auto import AutoLLM
|
||||
from openllm.utils import infer_auto_class
|
||||
from openllm_core._typing_compat import Concatenate
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
@@ -70,20 +59,21 @@ from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import Self
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.config import CONFIG_MAPPING
|
||||
from openllm_core.utils import DEBUG_ENV_VAR
|
||||
from openllm_core.utils import OPTIONAL_DEPENDENCIES
|
||||
from openllm_core.utils import QUIET_ENV_VAR
|
||||
from openllm_core.utils import EnvVarMixin
|
||||
from openllm_core.utils import LazyLoader
|
||||
from openllm_core.utils import analytics
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
from openllm_core.utils import compose
|
||||
from openllm_core.utils import configure_logging
|
||||
from openllm_core.utils import converter
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
from openllm_core.utils import is_torch_available
|
||||
from openllm_core.utils import is_vllm_available
|
||||
from openllm_core.utils import resolve_user_filepath
|
||||
from openllm_core.utils import set_debug_mode
|
||||
from openllm_core.utils import set_quiet_mode
|
||||
@@ -112,7 +102,7 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.container import DefaultBuilder
|
||||
from openllm_client._schemas import Response
|
||||
from openllm_client._schemas import StreamResponse
|
||||
from openllm_client._schemas import StreamingResponse
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry
|
||||
from openllm_core._typing_compat import LiteralContainerVersionStrategy
|
||||
else:
|
||||
@@ -347,9 +337,8 @@ _start_mapping = {
|
||||
@machine_option
|
||||
@backend_option
|
||||
@serialisation_option
|
||||
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend,
|
||||
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None,
|
||||
) -> bentoml.Model:
|
||||
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend | None,
|
||||
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None) -> bentoml.Model:
|
||||
"""Setup LLM interactively.
|
||||
|
||||
It accepts two positional arguments: `model_name` and `model_id`. The first name determine
|
||||
@@ -400,24 +389,19 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
|
||||
$ CONVERTER=llama2-hf openllm import llama /path/to/llama-2
|
||||
```
|
||||
"""
|
||||
llm_config = AutoConfig.for_model(model_name)
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
|
||||
backend = first_not_none(backend, default=env['backend_value'])
|
||||
llm = infer_auto_class(backend).for_model(model_name,
|
||||
model_id=env['model_id_value'],
|
||||
llm_config=llm_config,
|
||||
model_version=model_version,
|
||||
ensure_available=False,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
|
||||
env = EnvVarMixin(model_name, model_id=model_id, quantize=quantize)
|
||||
model_id = first_not_none(model_id, env['model_id_value'], default=llm_config['default_id'])
|
||||
backend = first_not_none(backend, env['backend_value'], default='vllm' if is_vllm_available() else 'pt')
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id=model_id, llm_config=llm_config, revision=model_version, quantize=env['quantize_value'], serialisation=_serialisation, backend=backend)
|
||||
_previously_saved = False
|
||||
try:
|
||||
_ref = openllm.serialisation.get(llm)
|
||||
_previously_saved = True
|
||||
except openllm.exceptions.OpenLLMException:
|
||||
if not machine and output == 'pretty':
|
||||
msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
|
||||
msg = f"'{model_name}' with model_id='{model_id}' does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
|
||||
termui.echo(msg, fg='yellow', nl=True)
|
||||
_ref = openllm.serialisation.get(llm, auto_import=True)
|
||||
if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
@@ -471,11 +455,10 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
|
||||
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
|
||||
@click.pass_context
|
||||
def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None,
|
||||
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend,
|
||||
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None,
|
||||
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
|
||||
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
|
||||
force_push: bool, **attrs: t.Any,
|
||||
) -> bentoml.Bento:
|
||||
force_push: bool, **attrs: t.Any) -> bentoml.Bento:
|
||||
'''Package a given models into a Bento.
|
||||
|
||||
\b
|
||||
@@ -498,9 +481,9 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
|
||||
|
||||
_previously_built = False
|
||||
|
||||
llm_config = AutoConfig.for_model(model_name)
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
|
||||
env = EnvVarMixin(model_name, backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), model_id=model_id or llm_config['default_id'], quantize=quantize)
|
||||
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
|
||||
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
@@ -509,21 +492,25 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
|
||||
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
|
||||
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
|
||||
if env['backend_value']: os.environ[env.backend] = str(env['backend_value'])
|
||||
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = infer_auto_class(env['backend_value']).for_model(model_name,
|
||||
model_id=env['model_id_value'],
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=llm_config,
|
||||
ensure_available=True,
|
||||
model_version=model_version,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation,
|
||||
**attrs)
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id=env['model_id_value'] or llm_config['default_id'],
|
||||
revision=model_version,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=llm_config,
|
||||
backend=env['backend_value'],
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation,
|
||||
**attrs)
|
||||
llm.save_pretrained() # ensure_available = True
|
||||
|
||||
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
|
||||
# FIX: This is a patch for _service_vars injection
|
||||
if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
|
||||
if 'OPENLLM_ADAPTER_MAP' not in os.environ: os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
|
||||
@@ -536,13 +523,13 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
|
||||
llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
|
||||
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
|
||||
|
||||
adapter_map: dict[str, str | None] | None = None
|
||||
adapter_map: dict[str, str] | None = None
|
||||
if adapter_id:
|
||||
if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
|
||||
adapter_map = {}
|
||||
for v in adapter_id:
|
||||
_adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
name = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
|
||||
try:
|
||||
resolve_user_filepath(_adapter_id, build_ctx)
|
||||
src_folder_name = os.path.basename(_adapter_id)
|
||||
@@ -558,7 +545,7 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
|
||||
adapter_map[_adapter_id] = name
|
||||
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(adapter_map).decode()
|
||||
|
||||
_bento_version = first_not_none(bento_version, default=llm.tag.version)
|
||||
_bento_version = first_not_none(bento_version, default=llm.bentomodel.tag.version)
|
||||
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{_bento_version}'.lower().strip())
|
||||
try:
|
||||
bento = bentoml.get(bento_tag)
|
||||
@@ -633,29 +620,17 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
if show_available: raise click.BadOptionUsage('--show-available', "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
|
||||
termui.echo('\n'.join(models), fg='white')
|
||||
else:
|
||||
failed_initialized: list[tuple[str, Exception]] = []
|
||||
|
||||
json_data: dict[str, dict[t.Literal['architecture', 'model_id', 'url', 'installation', 'cpu', 'gpu', 'backend'], t.Any] | t.Any] = {}
|
||||
converted: list[str] = []
|
||||
for m in models:
|
||||
config = AutoConfig.for_model(m)
|
||||
backend: tuple[str, ...] = ()
|
||||
if config['model_name'] in MODEL_MAPPING_NAMES: backend += ('pt',)
|
||||
if config['model_name'] in MODEL_FLAX_MAPPING_NAMES: backend += ('flax',)
|
||||
if config['model_name'] in MODEL_TF_MAPPING_NAMES: backend += ('tf',)
|
||||
if config['model_name'] in MODEL_VLLM_MAPPING_NAMES: backend += ('vllm',)
|
||||
config = openllm.AutoConfig.for_model(m)
|
||||
json_data[m] = {
|
||||
'architecture': config['architecture'],
|
||||
'model_id': config['model_ids'],
|
||||
'backend': backend,
|
||||
'backend': config['backend'],
|
||||
'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm',
|
||||
}
|
||||
converted.extend([normalise_model_name(i) for i in config['model_ids']])
|
||||
if DEBUG:
|
||||
try:
|
||||
AutoLLM.for_model(m, llm_config=config)
|
||||
except Exception as e:
|
||||
failed_initialized.append((m, e))
|
||||
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
@@ -680,22 +655,9 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])])
|
||||
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
|
||||
|
||||
if len(data) == 0 and len(failed_initialized) > 0:
|
||||
termui.echo('Exception found while parsing models:\n', fg='yellow')
|
||||
for m, err in failed_initialized:
|
||||
termui.echo(f'- {m}: ', fg='yellow', nl=False)
|
||||
termui.echo(traceback.print_exception(None, err, None, limit=5), fg='red') # type: ignore[func-returns-value]
|
||||
sys.exit(1)
|
||||
|
||||
table = tabulate.tabulate(data, tablefmt='fancy_grid', headers=['LLM', 'Architecture', 'Models Id', 'Installation', 'Runtime'], maxcolwidths=column_widths)
|
||||
termui.echo(table, fg='white')
|
||||
|
||||
if DEBUG and len(failed_initialized) > 0:
|
||||
termui.echo('\nThe following models are supported but failed to initialize:\n')
|
||||
for m, err in failed_initialized:
|
||||
termui.echo(f'- {m}: ', fg='blue', nl=False)
|
||||
termui.echo(err, fg='red')
|
||||
|
||||
if show_available:
|
||||
if len(ids_in_local_store) == 0:
|
||||
termui.echo('No models available locally.')
|
||||
@@ -837,14 +799,14 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
|
||||
termui.echo(f'{prompt}', fg=input_fg)
|
||||
|
||||
if stream:
|
||||
stream_res: t.Iterator[StreamResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
|
||||
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
for it in stream_res:
|
||||
termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
elif output == 'json':
|
||||
for it in stream_res:
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
termui.echo(orjson.dumps(converter.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
for it in stream_res:
|
||||
termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
@@ -852,11 +814,11 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
|
||||
res: Response = client.generate(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
termui.echo(res.responses[0], fg=generated_fg)
|
||||
termui.echo(res.outputs[0].text, fg=generated_fg)
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
termui.echo(orjson.dumps(converter.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(res.responses, fg='white')
|
||||
termui.echo(res.outputs[0].text, fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
|
||||
@@ -14,7 +14,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.container.generate import generate_containerfile
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
from openllm_core.utils import converter
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
@@ -35,7 +35,7 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
# construct_containerfile. Hence it is safe to set it to None here.
|
||||
# See https://github.com/bentoml/BentoML/issues/3399.
|
||||
docker_attrs = bentoml_cattr.unstructure(options.docker)
|
||||
docker_attrs = converter.unstructure(options.docker)
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
|
||||
openllm-python/src/openllm/entrypoints/__init__.py
@@ -0,0 +1,29 @@
'''Entrypoint for all third-party apps.

Currently supports an OpenAI-compatible API.

Each module should implement the following API:

- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
'''
from __future__ import annotations
import typing as t

from openllm_core.utils import LazyModule

from . import hf as hf
from . import openai as openai

if t.TYPE_CHECKING:
  import bentoml
  import openllm

_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}

def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
  return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)

__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
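Per the contract described in the docstring above, a third-party entrypoint only needs to expose `mount_to_svc`. A minimal hypothetical module following that shape (the route and mount path below are illustrative, not the real `hf`/`openai` implementations):

from __future__ import annotations
import typing as t

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

if t.TYPE_CHECKING:
  import bentoml
  import openllm

def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
  async def ping(_):  # stand-in route; a real integration would expose its own API surface
    return JSONResponse({'model_id': llm.model_id}, status_code=200)
  svc.mount_asgi_app(Starlette(routes=[Route('/ping', ping, methods=['GET'])]), path='/example')
  return svc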
openllm-python/src/openllm/entrypoints/_openapi.py
@@ -0,0 +1,518 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.routing import Host
|
||||
from starlette.routing import Mount
|
||||
from starlette.routing import Route
|
||||
from starlette.schemas import EndpointInfo
|
||||
from starlette.schemas import SchemaGenerator
|
||||
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
|
||||
import bentoml
|
||||
|
||||
P = ParamSpec('P')
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODEL_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >
|
||||
List and describe the various models available in the API.
|
||||
|
||||
You can refer to the available supported models with `openllm models` for more
|
||||
information.
|
||||
operationId: openai__list_models
|
||||
produces:
|
||||
- application/json
|
||||
summary: Describes a model offering that can be used with the API.
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: list_models
|
||||
responses:
|
||||
'200':
|
||||
description: The Model object
|
||||
content:
|
||||
application/json:
|
||||
example:
|
||||
id: davinci
|
||||
object: model
|
||||
created: 1686935002
|
||||
owned_by: openai
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
'''
|
||||
CHAT_COMPLETION_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a list of messages comprising a conversation, the model will return a
|
||||
response.
|
||||
operationId: openai__create_chat_completions
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: create_chat_completions
|
||||
summary: Creates a model response for the given chat conversation.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot input example
|
||||
value:
|
||||
messages:
|
||||
- role: system
|
||||
content: You are a helpful assistant.
|
||||
- role: user
|
||||
content: Hello, I'm looking for a chatbot that can help me with my work.
|
||||
model: meta-llama--Llama-2-13-chat-hf
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: false
|
||||
streaming:
|
||||
summary: Streaming input example
|
||||
value:
|
||||
messages:
|
||||
- role: system
|
||||
content: You are a helpful assistant.
|
||||
- role: user
|
||||
content: Hello, I'm looking for a chatbot that can help me with my work.
|
||||
model: meta-llama--Llama-2-13-chat-hf
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: true
|
||||
stop:
|
||||
- "\\n"
|
||||
- "<|endoftext|>"
|
||||
schema:
|
||||
$ref: '#/components/schemas/ChatCompletionRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ChatCompletionResponse'
|
||||
examples:
|
||||
streaming:
|
||||
summary: Streaming output example
|
||||
value: >
|
||||
{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
one-shot:
|
||||
summary: One-shot output example
|
||||
value: >
|
||||
{"id": "chatcmpl-123", "object": "chat.completion", "created": 1677652288, "model": "gpt-3.5-turbo-0613", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello there, how may I assist you today?"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}}
|
||||
'404':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
wrong-model:
|
||||
summary: Wrong model
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Model 'meta-llama--Llama-2-13-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 404
|
||||
}
|
||||
}
|
||||
description: NotFound
|
||||
'500':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-parameters:
|
||||
summary: Invalid parameters
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 500
|
||||
}
|
||||
}
|
||||
description: Internal Server Error
|
||||
'400':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-json:
|
||||
summary: Invalid JSON sent
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Invalid JSON input received (Check server log).",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
invalid-prompt:
|
||||
summary: Invalid prompt
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Please provide a prompt.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
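# Editorial sketch (assumption, illustration only): a client request matching the chat
# schema above. It assumes an OpenLLM server listening on localhost:3000 with the
# OpenAI-compatible app mounted under /v1; host, port and the helper name are hypothetical.
def _example_chat_completion_request() -> None:
  import requests
  payload = {
      'model': 'meta-llama--Llama-2-13-chat-hf',
      'messages': [{'role': 'system', 'content': 'You are a helpful assistant.'},
                   {'role': 'user', 'content': "Hello, I'm looking for a chatbot that can help me with my work."}],
      'max_tokens': 256, 'temperature': 0.7, 'top_p': 0.43, 'n': 1, 'stream': False,
  }
  resp = requests.post('http://localhost:3000/v1/chat/completions', json=payload, timeout=60)
  resp.raise_for_status()
  print(resp.json()['choices'][0]['message']['content'])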
COMPLETION_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a prompt, the model will return one or more predicted completions, and
|
||||
can also return the probabilities of alternative tokens at each position. We
|
||||
recommend most users use our Chat completions API.
|
||||
operationId: openai__create_completions
|
||||
produces:
|
||||
- application/json
|
||||
tags:
|
||||
- OpenAI
|
||||
x-bentoml-name: create_completions
|
||||
summary: Creates a completion for the provided prompt and parameters.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CompletionRequest'
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
model: meta-llama--Llama-2-13-chat-hf
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
logprobs: 1
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: false
|
||||
streaming:
|
||||
summary: Streaming input example
|
||||
value:
|
||||
prompt: This is a test
|
||||
model: meta-llama--Llama-2-13-chat-hf
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
logprobs: 1
|
||||
n: 1
|
||||
stream: true
|
||||
stop:
|
||||
- "\\n"
|
||||
- "<|endoftext|>"
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CompletionResponse'
|
||||
examples:
|
||||
one-shot:
|
||||
summary: One-shot output example
|
||||
value:
|
||||
id: cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7
|
||||
object: text_completion
|
||||
created: 1589478378
|
||||
model: VAR_model_id
|
||||
choices:
|
||||
- text: This is indeed a test
|
||||
index: 0
|
||||
logprobs: null
|
||||
finish_reason: length
|
||||
usage:
|
||||
prompt_tokens: 5
|
||||
completion_tokens: 7
|
||||
total_tokens: 12
|
||||
streaming:
|
||||
summary: Streaming output example
|
||||
value:
|
||||
id: cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe
|
||||
object: text_completion
|
||||
created: 1690759702
|
||||
choices:
|
||||
- text: This
|
||||
index: 0
|
||||
logprobs: null
|
||||
finish_reason: null
|
||||
model: gpt-3.5-turbo-instruct
|
||||
'404':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
wrong-model:
|
||||
summary: Wrong model
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Model 'meta-llama--Llama-2-13-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 404
|
||||
}
|
||||
}
|
||||
description: NotFound
|
||||
'500':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-parameters:
|
||||
summary: Invalid parameters
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 500
|
||||
}
|
||||
}
|
||||
description: Internal Server Error
|
||||
'400':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
examples:
|
||||
invalid-json:
|
||||
summary: Invalid JSON sent
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Invalid JSON input received (Check server log).",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
invalid-prompt:
|
||||
summary: Invalid prompt
|
||||
value: >
|
||||
{
|
||||
"error": {
|
||||
"message": "Please provide a prompt.",
|
||||
"type": "invalid_request_error",
|
||||
"object": "error",
|
||||
"param": null,
|
||||
"code": 400
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
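# Editorial sketch (assumption, illustration only): consuming the streaming variant documented
# above. The server emits Server-Sent Events, one 'data: <json>' line per chunk, terminated by
# 'data: [DONE]'. Host, port and the helper name are hypothetical.
def _example_stream_completions() -> None:
  import json
  import requests
  payload = {'prompt': 'This is a test', 'model': 'meta-llama--Llama-2-13-chat-hf', 'max_tokens': 256, 'stream': True}
  with requests.post('http://localhost:3000/v1/completions', json=payload, stream=True, timeout=None) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
      if not line or not line.startswith('data: '): continue
      data = line[len('data: '):]
      if data == '[DONE]': break
      chunk = json.loads(data)
      print(chunk['choices'][0]['text'], end='', flush=True)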
HF_AGENT_SCHEMA = '''\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
description: Generate instructions for a given HF Agent chain for all OpenLLM-supported models.
|
||||
operationId: hf__agent
|
||||
summary: Generate instructions for a given HF Agent.
|
||||
tags:
|
||||
- HF
|
||||
x-bentoml-name: hf_agent
|
||||
produces:
|
||||
- application/json
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentRequest'
|
||||
example:
|
||||
inputs: "Is the following `text` positive or negative?"
|
||||
parameters:
|
||||
text: "This is a positive text."
|
||||
stop: ["\n"]
|
||||
required: true
|
||||
responses:
|
||||
200:
|
||||
description: Successfully generated instruction.
|
||||
content:
|
||||
application/json:
|
||||
example:
|
||||
- generated_text: "This is a generated instruction."
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentResponse'
|
||||
400:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentErrorResponse'
|
||||
description: Bad Request
|
||||
500:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AgentErrorResponse'
|
||||
description: Internal Server Error
|
||||
'''
|
||||
|
||||
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
|
||||
def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
if func.__doc__ is None: func.__doc__ = ''
|
||||
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
|
||||
return func
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
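# Editorial sketch (assumption, illustration only): `add_schema_definitions` simply appends the
# given OpenAPI YAML document to the wrapped endpoint's docstring; the generator below later
# parses everything after the '---' marker. `_EXAMPLE_SCHEMA` and `_example_ping` are hypothetical.
_EXAMPLE_SCHEMA = '---\nsummary: Ping endpoint used purely for illustration.'

@add_schema_definitions(_EXAMPLE_SCHEMA)
async def _example_ping(req: t.Any) -> None:
  '''A do-nothing endpoint.'''

# After decoration, _example_ping.__doc__ ends with _EXAMPLE_SCHEMA.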
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
for route in routes:
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
|
||||
sub_endpoints = [EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func) for sub_endpoint in self.get_endpoints(routes)]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
if method == 'HEAD': continue
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
|
||||
if not hasattr(route.endpoint, method): continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
|
||||
def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]:
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
if mount_path: mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
if not parsed: continue
|
||||
|
||||
path = endpoint.path if mount_path is None else mount_path + endpoint.path
|
||||
if path not in schema['paths']: schema['paths'][path] = {}
|
||||
schema['paths'][path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
||||
|
||||
def get_generator(title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None) -> OpenLLMSchemaGenerator:
|
||||
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
|
||||
if components: base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags: base_schema['tags'] = tags
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
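# Editorial sketch (assumption, illustration only): end-to-end use of the pieces above. A
# docstring-annotated Starlette endpoint plus `get_generator(...).get_schema(...)` yields the
# OpenAPI paths that `append_schemas` later merges into the Bento service spec. Names prefixed
# with `_example` are hypothetical.
def _example_generated_paths() -> dict[str, t.Any]:
  @add_schema_definitions('---\nsummary: Ping endpoint used purely for illustration.')
  async def _example_ping(req: t.Any) -> None:
    '''A do-nothing endpoint.'''

  routes = [Route('/ping', endpoint=_example_ping, methods=['GET'])]
  generator = get_generator('example', tags=[{'name': 'Example', 'description': 'Illustration only'}])
  # The returned spec's 'paths' contains '/example/ping' with the parsed 'summary'.
  return generator.get_schema(routes=routes, mount_path='/example')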
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
|
||||
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc]
prop_schema: t.Optional[dict[str, t.Any]] = None # reset for every field so a previous field's schema is never reused
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
# Map Python types to OpenAPI schema types
|
||||
if attr_type == str: schema_type = 'string'
|
||||
elif attr_type == int: schema_type = 'integer'
|
||||
elif attr_type == float: schema_type = 'number'
|
||||
elif attr_type == bool: schema_type = 'boolean'
|
||||
elif origin_type is list or origin_type is tuple:
|
||||
schema_type = 'array'
|
||||
elif origin_type is dict:
|
||||
schema_type = 'object'
|
||||
# Assuming string keys for simplicity, and handling Any type for values
|
||||
prop_schema = {
|
||||
'type': 'object',
|
||||
'additionalProperties':
|
||||
True if args_type[1] is t.Any else {
|
||||
'type': 'string'
|
||||
} # Simplified
|
||||
}
|
||||
elif attr_type == t.Optional[str]:
|
||||
schema_type = 'string'
|
||||
elif origin_type is t.Union and t.Any in args_type:
|
||||
schema_type = 'object'
|
||||
prop_schema = {
|
||||
'type': 'object',
|
||||
'additionalProperties': True # Allows any type of values
|
||||
}
|
||||
else:
|
||||
schema_type = 'string'
|
||||
|
||||
if prop_schema is None: prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
|
||||
|
||||
return schema
|
||||
|
||||
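# Editorial sketch (assumption, illustration only): the kind of component schema produced by
# `component_schema_generator` for a small attrs class. `_ExamplePayload` is hypothetical.
@attr.define
class _ExamplePayload:
  '''Payload used purely to illustrate component_schema_generator.'''
  prompt: str
  max_tokens: int = 256

# component_schema_generator(_ExamplePayload) yields roughly:
#   {'type': 'object', 'title': '_ExamplePayload', 'required': ['prompt'],
#    'properties': {'prompt': {'type': 'string'}, 'max_tokens': {'type': 'integer', 'default': 256}},
#    'description': 'Payload used purely to illustrate component_schema_generator.'}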
class MKSchema:
|
||||
def __init__(self, it: dict[str, t.Any]) -> None:
|
||||
self.it = it
|
||||
|
||||
def asdict(self) -> dict[str, t.Any]:
|
||||
return self.it
|
||||
|
||||
def append_schemas(svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend') -> bentoml.Service:
|
||||
# HACK: Dirty hack to append schemas to an existing service. We definitely need proper support for mounting a Starlette app's OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
svc_schema: t.Any = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend': svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append': svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else: raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema: svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
svc_schema['paths'].update(generated_schema['paths'])
|
||||
|
||||
from bentoml._internal.service import openapi # HACK: monkey-patch this attribute until we have a better way to add Starlette schemas.
|
||||
|
||||
# yapf: disable
|
||||
def mk_generate_spec(svc: bentoml.Service, openapi_version: str = OPENAPI_VERSION) -> MKSchema: return MKSchema(svc_schema)
def mk_asdict(self: OpenAPISpecification) -> dict[str, t.Any]: return svc_schema
openapi.generate_spec = mk_generate_spec
setattr(OpenAPISpecification, 'asdict', mk_asdict)
|
||||
# yapf: enable
|
||||
return svc
|
||||
75  openllm-python/src/openllm/entrypoints/hf.py  Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
import orjson
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from openllm_core.utils import converter
|
||||
|
||||
from ._openapi import HF_AGENT_SCHEMA
|
||||
from ._openapi import add_schema_definitions
|
||||
from ._openapi import append_schemas
|
||||
from ._openapi import get_generator
|
||||
from ..protocol.hf import AgentErrorResponse
|
||||
from ..protocol.hf import AgentRequest
|
||||
from ..protocol.hf import AgentResponse
|
||||
|
||||
schemas = get_generator('hf',
|
||||
components=[AgentRequest, AgentResponse, AgentErrorResponse],
|
||||
tags=[{
|
||||
'name': 'HF',
|
||||
'description': 'Includes HF Agent support',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent'
|
||||
}])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False)])
|
||||
mount_path = '/hf'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema, tags_order='append')
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse(converter.unstructure(AgentErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
|
||||
|
||||
@add_schema_definitions(HF_AGENT_SCHEMA)
|
||||
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), AgentRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
|
||||
stop = request.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
|
||||
return JSONResponse(converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
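# Editorial sketch (assumption, illustration only): exercising the agent endpoint above once
# `mount_to_svc` has mounted this app under '/hf'. Host, port and the helper name are hypothetical.
def _example_hf_agent_request() -> None:
  import requests
  payload = {'inputs': 'Is the following `text` positive or negative?',
             'parameters': {'text': 'This is a positive text.', 'stop': ['\n']}}
  resp = requests.post('http://localhost:3000/hf/agent', json=payload, timeout=60)
  resp.raise_for_status()
  print(resp.json()[0]['generated_text'])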
305  openllm-python/src/openllm/entrypoints/openai.py  Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
import orjson
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from openllm_core._schemas import SampleLogprobs
|
||||
from openllm_core.utils import converter
|
||||
from openllm_core.utils import gen_random_uuid
|
||||
|
||||
from ._openapi import CHAT_COMPLETION_SCHEMA
|
||||
from ._openapi import COMPLETION_SCHEMA
|
||||
from ._openapi import LIST_MODEL_SCHEMA
|
||||
from ._openapi import add_schema_definitions
|
||||
from ._openapi import append_schemas
|
||||
from ._openapi import get_generator
|
||||
from ..protocol.openai import ChatCompletionRequest
|
||||
from ..protocol.openai import ChatCompletionResponse
|
||||
from ..protocol.openai import ChatCompletionResponseChoice
|
||||
from ..protocol.openai import ChatCompletionResponseStreamChoice
|
||||
from ..protocol.openai import ChatCompletionStreamResponse
|
||||
from ..protocol.openai import ChatMessage
|
||||
from ..protocol.openai import CompletionRequest
|
||||
from ..protocol.openai import CompletionResponse
|
||||
from ..protocol.openai import CompletionResponseChoice
|
||||
from ..protocol.openai import CompletionResponseStreamChoice
|
||||
from ..protocol.openai import CompletionStreamResponse
|
||||
from ..protocol.openai import Delta
|
||||
from ..protocol.openai import ErrorResponse
|
||||
from ..protocol.openai import LogProbs
|
||||
from ..protocol.openai import ModelCard
|
||||
from ..protocol.openai import ModelList
|
||||
from ..protocol.openai import UsageInfo
|
||||
from ..protocol.openai import get_conversation_prompt
|
||||
|
||||
schemas = get_generator(
|
||||
'openai',
|
||||
components=[ErrorResponse, ModelList, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionStreamResponse, CompletionRequest, CompletionResponse, CompletionStreamResponse],
|
||||
tags=[{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object'
|
||||
}])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
def jsonify_attr(obj: AttrsInstance) -> str:
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse({'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))}, status_code=status_code.value)
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
if request.model == model: return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request."
|
||||
)
|
||||
|
||||
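# Editorial note (assumption): `create_logprobs` below converts the engine's per-token
# log-probability dicts into the OpenAI `logprobs` layout: parallel `tokens`, `token_logprobs`
# and cumulative `text_offset` lists, plus one `top_logprobs` mapping of token string to
# log-probability per generated position.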
def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]) -> LogProbs:
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
for token_id, id_logprob in zip(token_ids, id_logprobs):
|
||||
token = llm.tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(id_logprob[token_id])
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
|
||||
return logprobs
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST'])
|
||||
])
|
||||
mount_path = '/v1'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema)
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions(LIST_MODEL_SCHEMA)
|
||||
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
|
||||
async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.time())  # Unix epoch seconds; a monotonic clock is not a valid 'created' timestamp
|
||||
prompt = await get_conversation_prompt(request, llm.config)
|
||||
config = llm.config.with_openai_request(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
|
||||
return jsonify_attr(
|
||||
ChatCompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)]))
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
# first chunk with role
|
||||
for i in range(config['n']):
|
||||
yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
|
||||
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
|
||||
if output.finish_reason is not None:
|
||||
yield f'data: {create_stream_response_json(output.index, "", output.finish_reason)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[] for _ in range(config['n'])]  # independent lists; [[]] * n would alias a single list
token_ids: list[list[int]] = [[] for _ in range(config['n'])]
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role='assistant', content=output.text), finish_reason=output.finish_reason) for output in final_result.outputs
|
||||
]
|
||||
num_prompt_tokens, num_generated_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)), sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream: # type: ignore[unreachable]
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
async def fake_stream_generator() -> t.AsyncGenerator[str, None]: # type: ignore[unreachable]
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
# POST /v1/completions
|
||||
@add_schema_definitions(COMPLETION_SCHEMA)
|
||||
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
|
||||
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`.")
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.time())  # Unix epoch seconds; a monotonic clock is not a valid 'created' timestamp
|
||||
config = llm.config.with_openai_request(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
# If best_of != n, we don't stream
|
||||
# TODO: support use_beam_search
|
||||
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
|
||||
|
||||
def create_stream_response_json(index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None) -> str:
|
||||
return jsonify_attr(
|
||||
CompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)]))
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
previous_num_tokens = [0] * config['n']
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i]:], llm=llm)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
|
||||
if output.finish_reason is not None:
|
||||
logprobs = LogProbs() if request.logprobs is not None else None
|
||||
yield f'data: {create_stream_response_json(index=i, text="", logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[] for _ in range(config['n'])]  # independent lists; [[]] * n would alias a single list
token_ids: list[list[int]] = [[] for _ in range(config['n'])]
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
for output in final_result.outputs:
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm)
|
||||
else:
|
||||
logprobs = None
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
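# Editorial sketch (assumption, illustration only): wiring both entrypoints onto a Bento service.
# The `mount_to_svc` signatures match the functions in this diff; the `openllm.LLM` constructor
# arguments and service name are assumptions and purely hypothetical.
#
#   import bentoml, openllm
#   from openllm.entrypoints import hf, openai
#
#   llm = openllm.LLM('facebook/opt-1.3b')        # assumption: the unified LLM accepts a model id
#   svc = bentoml.Service('llm-service')          # runners omitted in this sketch
#   svc = openai.mount_to_svc(svc, llm)           # OpenAI-compatible routes under /v1
#   svc = hf.mount_to_svc(svc, llm)               # HF Agent route under /hf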
11  openllm-python/src/openllm/models/__init__.py  generated
@@ -1,11 +0,0 @@
|
||||
# This file is generated by tools/update-models-import.py. DO NOT EDIT MANUALLY!
|
||||
# To update this, run ./tools/update-models-import.py
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from openllm_core.utils import LazyModule
|
||||
_MODELS:set[str]={"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"}
|
||||
if t.TYPE_CHECKING:from . import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder
|
||||
__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})
|
||||
__all__=__lazy.__all__
|
||||
__dir__=__lazy.__dir__
|
||||
__getattr__=__lazy.__getattr__
|
||||
@@ -1,66 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
from openllm_core.config import AutoConfig as AutoConfig
|
||||
from openllm_core.utils import LazyModule
|
||||
from openllm_core.utils import is_flax_available
|
||||
from openllm_core.utils import is_tf_available
|
||||
from openllm_core.utils import is_torch_available
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'modeling_auto': ['MODEL_MAPPING_NAMES'],
|
||||
'modeling_flax_auto': ['MODEL_FLAX_MAPPING_NAMES'],
|
||||
'modeling_tf_auto': ['MODEL_TF_MAPPING_NAMES'],
|
||||
'modeling_vllm_auto': ['MODEL_VLLM_MAPPING_NAMES']
|
||||
}
|
||||
if t.TYPE_CHECKING:
|
||||
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
|
||||
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
try:
|
||||
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_auto'].extend(['AutoLLM', 'MODEL_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
|
||||
try:
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_auto'].extend(['AutoVLLM', 'MODEL_VLLM_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
|
||||
try:
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_flax_auto'].extend(['AutoFlaxLLM', 'MODEL_FLAX_MAPPING'])
|
||||
if t.TYPE_CHECKING:
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_tf_auto'].extend(['AutoTFLLM', 'MODEL_TF_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
|
||||
__lazy = LazyModule(__name__,
|
||||
os.path.abspath('__file__'),
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'CONFIG_MAPPING': CONFIG_MAPPING,
|
||||
'CONFIG_MAPPING_NAMES': CONFIG_MAPPING_NAMES,
|
||||
'AutoConfig': AutoConfig,
|
||||
})
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
@@ -1,181 +0,0 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
from openllm_core.utils import ReprMixin
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from collections import _odict_items
|
||||
from collections import _odict_keys
|
||||
from collections import _odict_values
|
||||
|
||||
from _typeshed import SupportsIter
|
||||
|
||||
from openllm_core._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import LLMRunner
|
||||
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAutoLLMClass:
|
||||
_model_mapping: t.ClassVar[_LazyAutoMapping]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
|
||||
|
||||
@classmethod
|
||||
def for_model(cls,
|
||||
model: str,
|
||||
/,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
ensure_available: bool = False,
|
||||
**attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
'''The lower level API for creating a LLM instance.
|
||||
|
||||
```python
|
||||
>>> import openllm
|
||||
>>> llm = openllm.AutoLLM.for_model("flan-t5")
|
||||
```
|
||||
'''
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id=model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
if ensure_available: llm.save_pretrained()
|
||||
return llm
|
||||
|
||||
@classmethod
|
||||
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
'''Create a LLM Runner for the given model name.
|
||||
|
||||
Args:
|
||||
model: The model name to instantiate.
|
||||
model_id: The pretrained model name to instantiate.
|
||||
**attrs: Additional keyword arguments passed along to the specific configuration class.
|
||||
|
||||
Returns:
|
||||
A LLM instance.
|
||||
'''
|
||||
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
|
||||
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
|
||||
for k in runner_attrs:
|
||||
del attrs[k]
|
||||
return cls.for_model(model, model_id=model_id, **attrs).to_runner(**runner_attrs)
|
||||
|
||||
@classmethod
|
||||
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]) -> None:
|
||||
'''Register a new model for this class.
|
||||
|
||||
Args:
|
||||
config_class: The configuration corresponding to the model to register.
|
||||
llm_class: The runnable to register.
|
||||
'''
|
||||
if hasattr(llm_class, 'config_class') and llm_class.config_class is not config_class:
|
||||
raise ValueError(
|
||||
f'The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!'
|
||||
)
|
||||
cls._model_mapping.register(config_class, llm_class)
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
config_class = openllm.AutoConfig.infer_class_from_name(name)
|
||||
if config_class in cls._model_mapping: return cls._model_mapping[config_class]
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class ({config_class}) for {name}. Model name should be one of {', '.join(openllm.CONFIG_MAPPING.keys())} (Registered configuration class: {', '.join([i.__name__ for i in cls._model_mapping.keys()])})."
|
||||
)
|
||||
|
||||
def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
if attr is None: return
|
||||
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
|
||||
if hasattr(module, attr): return getattr(module, attr)
|
||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
|
||||
openllm_module = importlib.import_module('openllm')
|
||||
if module != openllm_module:
|
||||
try:
|
||||
return getattribute_from_module(openllm_module, attr)
|
||||
except ValueError:
|
||||
raise ValueError(f'Could not find {attr} neither in {module} nor in {openllm_module}!') from None
|
||||
raise ValueError(f'Could not find {attr} in {openllm_module}!')
|
||||
|
||||
class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
|
||||
This OrderedDict values() and keys() returns the list instead, so you don't
|
||||
have to do list(mapping.values()) to get the list of values.
|
||||
"""
|
||||
def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString], model_mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._config_mapping = config_mapping
|
||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||
self._model_mapping = model_mapping
|
||||
self._extra_content: dict[t.Any, t.Any] = {}
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
model_type = self._reverse_config_mapping[key.__name__]
|
||||
if model_type in self._model_mapping:
|
||||
return self._load_attr_from_module(model_type, self._model_mapping[model_type])
|
||||
# Maybe there was several model types associated with this config.
|
||||
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
|
||||
for mtype in model_types:
|
||||
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
|
||||
raise KeyError(key)
|
||||
|
||||
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
|
||||
module_name = inflection.underscore(model_type)
|
||||
if module_name not in self._modules:
|
||||
self._modules[module_name] = importlib.import_module(f'.{module_name}', 'openllm.models')
|
||||
return getattribute_from_module(self._modules[module_name], attr)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return set(self._config_mapping.keys())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ReprMixin.__repr__(self)
|
||||
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]:
|
||||
yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.keys())
|
||||
|
||||
def keys(self) -> ConfigModelKeysView:
|
||||
return t.cast('ConfigModelKeysView',
|
||||
[self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
|
||||
|
||||
def values(self) -> ConfigModelValuesView:
|
||||
return t.cast('ConfigModelValuesView',
|
||||
[self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
|
||||
|
||||
def items(self) -> ConfigModelItemsView:
|
||||
return t.cast('ConfigModelItemsView', [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key]))
|
||||
for key in self._model_mapping.keys()
|
||||
if key in self._config_mapping.keys()] + list(self._extra_content.items()))
|
||||
|
||||
def __iter__(self) -> t.Iterator[type[openllm.LLMConfig]]:
|
||||
return iter(t.cast('SupportsIter[t.Iterator[type[openllm.LLMConfig]]]', self.keys()))
|
||||
|
||||
def __contains__(self, item: t.Any) -> bool:
|
||||
if item in self._extra_content: return True
|
||||
if not hasattr(item, '__name__') or item.__name__ not in self._reverse_config_mapping: return False
|
||||
return self._reverse_config_mapping[item.__name__] in self._model_mapping
|
||||
|
||||
def register(self, key: t.Any, value: t.Any) -> None:
|
||||
if hasattr(key, '__name__') and key.__name__ in self._reverse_config_mapping:
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys():
|
||||
raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
__all__ = ['BaseAutoLLMClass', '_LazyAutoMapping']
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLM'), ('dolly_v2', 'DollyV2'), ('falcon', 'Falcon'), ('flan_t5', 'FlanT5'), ('gpt_neox', 'GPTNeoX'), ('llama', 'Llama'), ('mpt', 'MPT'),
|
||||
('opt', 'OPT'), ('stablelm', 'StableLM'), ('starcoder', 'StarCoder'), ('baichuan', 'Baichuan')])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_MAPPING
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([('flan_t5', 'FlaxFlanT5'), ('opt', 'FlaxOPT')])
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
class AutoFlaxLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_FLAX_MAPPING
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([('flan_t5', 'TFFlanT5'), ('opt', 'TFOPT')])
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
class AutoTFLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_TF_MAPPING
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([('baichuan', 'VLLMBaichuan'), ('dolly_v2', 'VLLMDollyV2'), ('falcon', 'VLLMFalcon'), ('gpt_neox', 'VLLMGPTNeoX'), ('mpt', 'VLLMMPT'),
|
||||
('opt', 'VLLMOPT'), ('stablelm', 'VLLMStableLM'), ('starcoder', 'VLLMStarCoder'), ('llama', 'VLLMLlama')])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
|
||||
class AutoVLLM(BaseAutoLLMClass):
|
||||
_model_mapping: t.ClassVar = MODEL_VLLM_MAPPING
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_cpm_kernels_available
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_baichuan import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_baichuan import START_BAICHUAN_COMMAND_DOCSTRING as START_BAICHUAN_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_baichuan import BaichuanConfig as BaichuanConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_baichuan'] = ['Baichuan']
|
||||
if t.TYPE_CHECKING: from .modeling_baichuan import Baichuan as Baichuan
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_baichuan'] = ['VLLMBaichuan']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_baichuan import VLLMBaichuan as VLLMBaichuan
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_BAICHUAN_COMMAND_DOCSTRING': START_BAICHUAN_COMMAND_DOCSTRING,
|
||||
'BaichuanConfig': BaichuanConfig
|
||||
})
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class Baichuan(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16): # type: ignore[attr-defined]
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
@@ -1,9 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMBaichuan(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,29 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_cpm_kernels_available
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm_core.config.configuration_chatglm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_chatglm import START_CHATGLM_COMMAND_DOCSTRING as START_CHATGLM_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_chatglm import ChatGLMConfig as ChatGLMConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available() or not is_cpm_kernels_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_chatglm'] = ['ChatGLM']
|
||||
if t.TYPE_CHECKING: from .modeling_chatglm import ChatGLM as ChatGLM
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_CHATGLM_COMMAND_DOCSTRING': START_CHATGLM_COMMAND_DOCSTRING,
|
||||
'ChatGLMConfig': ChatGLMConfig
|
||||
})
|
||||
@@ -1,17 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
class ChatGLM(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
self.model.eval()
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_dolly_v2 import START_DOLLY_V2_COMMAND_DOCSTRING as START_DOLLY_V2_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_dolly_v2 import DollyV2Config as DollyV2Config
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_dolly_v2'] = ['DollyV2']
|
||||
if t.TYPE_CHECKING: from .modeling_dolly_v2 import DollyV2 as DollyV2
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_dolly_v2'] = ['VLLMDollyV2']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_dolly_v2 import VLLMDollyV2 as VLLMDollyV2
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_DOLLY_V2_COMMAND_DOCSTRING': START_DOLLY_V2_COMMAND_DOCSTRING,
|
||||
'DollyV2Config': DollyV2Config
|
||||
})
|
||||
@@ -1,141 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.config.configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_dolly_v2 import END_KEY
|
||||
from openllm_core.config.configuration_dolly_v2 import RESPONSE_KEY
|
||||
from openllm_core.config.configuration_dolly_v2 import get_special_token_id
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else:
|
||||
torch, transformers, tf = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(),
|
||||
'transformers'), openllm.utils.LazyLoader('tf', globals(), 'tensorflow')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]:
|
||||
...
|
||||
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any):
|
||||
super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
try:
|
||||
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
|
||||
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
|
||||
# Ensure generation stops once it generates "### End"
|
||||
generate_kwargs['eos_token_id'] = end_key_token_id
|
||||
except ValueError:
|
||||
pass
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {'response_key_token_id': response_key_token_id, 'end_key_token_id': end_key_token_id}
|
||||
if return_full_text is not None: postprocess_params['return_full_text'] = return_full_text
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def preprocess(self, input_: str, **generate_kwargs: t.Any) -> t.Dict[str, t.Any]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
|
||||
inputs = self.tokenizer(prompt_text, return_tensors='pt')
|
||||
inputs['prompt_text'] = prompt_text
|
||||
inputs['instruction_text'] = input_
|
||||
return t.cast(t.Dict[str, t.Any], inputs)
|
||||
|
||||
def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = input_tensors['input_ids'], input_tensors.get('attention_mask', None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
else: in_b = input_ids.shape[0]
|
||||
generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None,
|
||||
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
**generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == 'pt':
|
||||
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == 'tf':
|
||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = input_tensors.pop('instruction_text')
|
||||
return {'generated_sequence': generated_sequence, 'input_ids': input_ids, 'instruction_text': instruction_text}
|
||||
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
_generated_sequence, instruction_text = model_outputs['generated_sequence'][0], model_outputs['instruction_text']
|
||||
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
|
||||
records: list[dict[t.Literal['generated_text'], str]] = []
|
||||
for sequence in generated_sequence:
|
||||
# The response will be set to this variable if we can identify it.
|
||||
decoded = None
|
||||
# If we have token IDs for the response and end, then we can find the tokens and only decode between them.
|
||||
if response_key_token_id and end_key_token_id:
|
||||
# Find where "### Response:" is first found in the generated tokens. Considering this is part of the
|
||||
# prompt, we should definitely find it. We will return the tokens found after this token.
|
||||
try:
|
||||
response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError:
|
||||
response_pos = None
|
||||
if response_pos is None:
|
||||
logger.warning('Could not find response key %s in: %s', response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
# this token, as the response could be truncated. If we don't find it then just return everything
|
||||
# to the end. Note that even though we set eos_token_id, we still see this token at the end.
|
||||
try:
|
||||
end_pos = sequence.index(end_key_token_id)
|
||||
except ValueError:
|
||||
end_pos = None
|
||||
decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip()
|
||||
if not decoded:
|
||||
# Otherwise we'll decode everything and use a regex to find the response and end.
|
||||
fully_decoded = self.tokenizer.decode(sequence)
|
||||
# The response appears after "### Response:". The model has been trained to append "### End" at the
|
||||
# end.
|
||||
m = re.search(r'#+\s*Response:\s*(.+?)#+\s*End', fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else:
|
||||
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
|
||||
# return everything after "### Response:".
|
||||
m = re.search(r'#+\s*Response:\s*(.+)', fully_decoded, flags=re.DOTALL)
|
||||
if m: decoded = m.group(1).strip()
|
||||
else: logger.warning('Failed to find response in:\n%s', fully_decoded)
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
# trained on, but to the client it will appear to be the full text.
|
||||
if return_full_text: decoded = f'{instruction_text}\n{decoded}'
|
||||
records.append({'generated_text': t.cast(str, decoded)})
|
||||
return records
|
||||
|
||||
return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline
|
||||
|
||||
class DollyV2(openllm.LLM['transformers.Pipeline', 'transformers.PreTrainedTokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16}, {}
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline:
|
||||
return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
with torch.inference_mode():
|
||||
return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
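To make the postprocessing above concrete, here is a small, self-contained example of the same regex fallback used when the response/end token IDs cannot be located; the decoded text is a made-up illustration:

import re

# Hypothetical decoded output; the real text comes from tokenizer.decode(sequence).
fully_decoded = '### Instruction:\nSay hi\n### Response:\nHello there!\n### End'
m = re.search(r'#+\s*Response:\s*(.+?)#+\s*End', fully_decoded, flags=re.DOTALL)
print(m.group(1).strip() if m else None)  # -> 'Hello there!'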
|
||||
@@ -1,12 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMDollyV2(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizer']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_falcon import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_falcon import START_FALCON_COMMAND_DOCSTRING as START_FALCON_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_falcon import FalconConfig as FalconConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_falcon'] = ['Falcon']
|
||||
if t.TYPE_CHECKING: from .modeling_falcon import Falcon as Falcon
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_falcon'] = ['VLLMFalcon']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_falcon import VLLMFalcon as VLLMFalcon
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_FALCON_COMMAND_DOCSTRING': START_FALCON_COMMAND_DOCSTRING,
|
||||
'FalconConfig': FalconConfig
|
||||
})
|
||||
@@ -1,22 +0,0 @@
from __future__ import annotations
import typing as t

import openllm
if t.TYPE_CHECKING: import torch, transformers
else:
  torch, transformers = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers')

class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
  __openllm_internal__ = True

  @property
  def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
    return {'torch_dtype': torch.bfloat16, 'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}

  def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
    eos_token_id, inputs = attrs.pop('eos_token_id', self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors='pt').to(self.device)
    with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16):  # type: ignore[attr-defined]
      return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs['input_ids'],
                                                             attention_mask=inputs['attention_mask'],
                                                             generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()),
                                         skip_special_tokens=True)
@@ -1,12 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VLLMFalcon(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_flax_available
|
||||
from openllm.utils import is_tf_available
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm_core.config.configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_flan_t5 import START_FLAN_T5_COMMAND_DOCSTRING as START_FLAN_T5_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_flan_t5 import FlanT5Config as FlanT5Config
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_flan_t5'] = ['FlanT5']
|
||||
if t.TYPE_CHECKING: from .modeling_flan_t5 import FlanT5 as FlanT5
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_flax_flan_t5'] = ['FlaxFlanT5']
|
||||
if t.TYPE_CHECKING: from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_tf_flan_t5'] = ['TFFlanT5']
|
||||
if t.TYPE_CHECKING: from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)
|
||||
@@ -1,17 +0,0 @@
from __future__ import annotations
import typing as t

import openllm
if t.TYPE_CHECKING:
  import transformers

class FlanT5(openllm.LLM['transformers.T5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
  __openllm_internal__ = True

  def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
    import torch
    with torch.inference_mode():
      return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
                                                             do_sample=True,
                                                             generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
                                         skip_special_tokens=True)
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.config.configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class FlaxFlanT5(openllm.LLM['transformers.FlaxT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
decoder_start_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if decoder_start_token_id is None: decoder_start_token_id = 0
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
'decoder_start_token_id': decoder_start_token_id
|
||||
}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
|
||||
decoder_start_token_id = attrs.pop('decoder_start_token_id', 0)
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors='np')['input_ids'],
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
decoder_start_token_id=decoder_start_token_id).sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True)
|
||||
@@ -1,14 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFFlanT5(openllm.LLM['transformers.TFT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors='tf').input_ids,
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_gpt_neox import START_GPT_NEOX_COMMAND_DOCSTRING as START_GPT_NEOX_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_gpt_neox import GPTNeoXConfig as GPTNeoXConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_gpt_neox'] = ['GPTNeoX']
|
||||
if t.TYPE_CHECKING: from .modeling_gpt_neox import GPTNeoX as GPTNeoX
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_gpt_neox'] = ['VLLMGPTNeoX']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_gpt_neox import VLLMGPTNeoX as VLLMGPTNeoX
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_GPT_NEOX_COMMAND_DOCSTRING': START_GPT_NEOX_COMMAND_DOCSTRING,
|
||||
'GPTNeoXConfig': GPTNeoXConfig
|
||||
})
|
||||
@@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
@@ -1,9 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMGPTNeoX(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_llama import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_llama import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from openllm_core.config.configuration_llama import START_LLAMA_COMMAND_DOCSTRING as START_LLAMA_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_llama import LlamaConfig as LlamaConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_llama'] = ['VLLMLlama']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_llama import VLLMLlama as VLLMLlama
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_llama'] = ['Llama']
|
||||
if t.TYPE_CHECKING: from .modeling_llama import Llama as Llama
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_LLAMA_COMMAND_DOCSTRING': START_LLAMA_COMMAND_DOCSTRING,
|
||||
'LlamaConfig': LlamaConfig,
|
||||
'PROMPT_MAPPING': PROMPT_MAPPING
|
||||
})
|
||||
@@ -1,14 +0,0 @@
from __future__ import annotations
import typing as t

import openllm
if t.TYPE_CHECKING:
  import transformers

class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaTokenizerFast']):
  __openllm_internal__ = True

  @property
  def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
    import torch
    return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
@@ -1,8 +0,0 @@
from __future__ import annotations
import typing as t

import openllm
if t.TYPE_CHECKING: import vllm, transformers

class VLLMLlama(openllm.LLM['vllm.LLMEngine', 'transformers.LlamaTokenizerFast']):
  __openllm_internal__ = True
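The import_kwargs properties in these removed model files all rely on the same capability probes; a standalone sketch using plain PyTorch (no OpenLLM involved, for illustration only):

import torch

# Half precision on GPU, full precision on CPU; shard across GPUs only when more than one is present.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device_map = 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None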
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_mpt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_mpt import PROMPT_MAPPING as PROMPT_MAPPING
|
||||
from openllm_core.config.configuration_mpt import START_MPT_COMMAND_DOCSTRING as START_MPT_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_mpt import MPTConfig as MPTConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_mpt'] = ['MPT']
|
||||
if t.TYPE_CHECKING: from .modeling_mpt import MPT as MPT
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_mpt'] = ['VLLMMPT']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_mpt import VLLMMPT as VLLMMPT
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_MPT_COMMAND_DOCSTRING': START_MPT_COMMAND_DOCSTRING,
|
||||
'MPTConfig': MPTConfig,
|
||||
'PROMPT_MAPPING': PROMPT_MAPPING
|
||||
})
|
||||
@@ -1,88 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm.utils import generate_labels
|
||||
from openllm.utils import is_triton_available
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_mpt_config(model_id_or_path: str,
|
||||
max_sequence_length: int,
|
||||
device: torch.device | str | int | None,
|
||||
device_map: str | None = None,
|
||||
trust_remote_code: bool = True) -> transformers.PretrainedConfig:
|
||||
import torch
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, 'init_device') and device_map is None and isinstance(device, (str, torch.device)):
|
||||
config.init_device = str(device)
|
||||
if hasattr(config, 'attn_config') and is_triton_available(): config.attn_config['attn_impl'] = 'triton'
|
||||
else:
|
||||
logger.debug(
|
||||
"'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'"
|
||||
)
|
||||
# setting max_seq_len
|
||||
config.max_seq_len = max_sequence_length
|
||||
return config
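A hedged usage sketch for the get_mpt_config helper above; the model id and sequence length are hypothetical example values, not taken from this diff:

# Illustrative only: resolve an MPT config for a local single-GPU setup.
config = get_mpt_config('mosaicml/mpt-7b-instruct', max_sequence_length=4096, device='cuda:0')
assert config.max_seq_len == 4096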
|
||||
|
||||
class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch
|
||||
import transformers
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
torch_dtype = attrs.pop('torch_dtype', torch.bfloat16 if torch.cuda.is_available() else torch.float32)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
attrs.pop('low_cpu_mem_usage', None)
|
||||
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
|
||||
import transformers
|
||||
torch_dtype = attrs.pop('torch_dtype', torch.bfloat16 if torch.cuda.is_available() else torch.float32)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
trust_remote_code = attrs.pop('trust_remote_code', True)
|
||||
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
**attrs)
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
attrs = {
|
||||
'do_sample': False if llm_config['temperature'] == 0 else True,
|
||||
'eos_token_id': self.tokenizer.eos_token_id,
|
||||
'pad_token_id': self.tokenizer.pad_token_id,
|
||||
'generation_config': llm_config.to_generation_config()
|
||||
}
|
||||
with torch.inference_mode():
|
||||
if torch.cuda.is_available():
|
||||
with torch.autocast('cuda', torch.float16): # type: ignore[attr-defined]
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
else:
|
||||
generated_tensors = self.model.generate(**inputs, **attrs)
|
||||
return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True)
|
||||
@@ -1,9 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers, vllm
|
||||
|
||||
class VLLMMPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,52 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_flax_available
|
||||
from openllm.utils import is_tf_available
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_opt import START_OPT_COMMAND_DOCSTRING as START_OPT_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_opt import OPTConfig as OPTConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_opt'] = ['OPT']
|
||||
if t.TYPE_CHECKING: from .modeling_opt import OPT as OPT
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_flax_opt'] = ['FlaxOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_flax_opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_opt'] = ['VLLMOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_opt import VLLMOPT as VLLMOPT
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_tf_opt'] = ['TFOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_tf_opt import TFOPT as TFOPT
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_OPT_COMMAND_DOCSTRING': START_OPT_COMMAND_DOCSTRING,
|
||||
'OPTConfig': OPTConfig,
|
||||
})
|
||||
@@ -1,47 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FlaxOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs),
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors='np'),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences,
|
||||
skip_special_tokens=True)
|
||||
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OPT(openllm.LLM['transformers.OPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
@@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core.utils import generate_labels
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class TFOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import transformers
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs),
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors='tf'),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
@@ -1,26 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMOPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences
|
||||
}, {}
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_stablelm import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_stablelm import START_STABLELM_COMMAND_DOCSTRING as START_STABLELM_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_stablelm import StableLMConfig as StableLMConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_stablelm'] = ['StableLM']
|
||||
if t.TYPE_CHECKING: from .modeling_stablelm import StableLM as StableLM
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_stablelm'] = ['VLLMStableLM']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_stablelm import VLLMStableLM as VLLMStableLM
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_STABLELM_COMMAND_DOCSTRING': START_STABLELM_COMMAND_DOCSTRING,
|
||||
'StableLMConfig': StableLMConfig,
|
||||
})
|
||||
@@ -1,26 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
class StableLM(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
return [
|
||||
self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()]))[0],
|
||||
skip_special_tokens=True)
|
||||
]
|
||||
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMStableLM(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from openllm.exceptions import MissingDependencyError
|
||||
from openllm.utils import LazyModule
|
||||
from openllm.utils import is_torch_available
|
||||
from openllm.utils import is_vllm_available
|
||||
from openllm_core.config.configuration_starcoder import DEFAULT_PROMPT_TEMPLATE as DEFAULT_PROMPT_TEMPLATE
|
||||
from openllm_core.config.configuration_starcoder import START_STARCODER_COMMAND_DOCSTRING as START_STARCODER_COMMAND_DOCSTRING
|
||||
from openllm_core.config.configuration_starcoder import StarCoderConfig as StarCoderConfig
|
||||
|
||||
_import_structure: dict[str, list[str]] = {}
|
||||
try:
|
||||
if not is_torch_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_starcoder'] = ['StarCoder']
|
||||
if t.TYPE_CHECKING: from .modeling_starcoder import StarCoder as StarCoder
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_vllm_starcoder'] = ['VLLMStarCoder']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_starcoder import VLLMStarCoder as VLLMStarCoder
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
extra_objects={
|
||||
'DEFAULT_PROMPT_TEMPLATE': DEFAULT_PROMPT_TEMPLATE,
|
||||
'START_STARCODER_COMMAND_DOCSTRING': START_STARCODER_COMMAND_DOCSTRING,
|
||||
'StarCoderConfig': StarCoderConfig,
|
||||
})
|
||||
@@ -1,32 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm.utils import generate_labels
|
||||
from openllm_core.config.configuration_starcoder import EOD
|
||||
from openllm_core.config.configuration_starcoder import FIM_MIDDLE
|
||||
from openllm_core.config.configuration_starcoder import FIM_PAD
|
||||
from openllm_core.config.configuration_starcoder import FIM_PREFIX
|
||||
from openllm_core.config.configuration_starcoder import FIM_SUFFIX
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.GPT2TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch
|
||||
import transformers
|
||||
torch_dtype, device_map = attrs.pop('torch_dtype', torch.float16), attrs.pop('device_map', 'auto')
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], 'pad_token': EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import vllm, transformers
|
||||
|
||||
class VLLMStarCoder(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
@@ -56,13 +56,14 @@ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
else:
  model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())

model, tokenizer = openllm.AutoLLM.for_model("falcon", model_id=model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16,
                                              ensure_available=True).prepare_for_training(adapter_type="lora",
                                                                                           lora_alpha=16,
                                                                                           lora_dropout=0.1,
                                                                                           r=16,
                                                                                           bias="none",
                                                                                           target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"])
llm = openllm.LLM(model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
llm.save_pretrained()
model, tokenizer = llm.prepare_for_training(adapter_type="lora",
                                            lora_alpha=16,
                                            lora_dropout=0.1,
                                            r=16,
                                            bias="none",
                                            target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"])
model.config.use_cache = False
tokenizer.pad_token = tokenizer.eos_token

@@ -3,6 +3,7 @@ import argparse
import logging
import typing as t

import asyncio
import openllm

openllm.utils.configure_logging()
@@ -11,45 +12,36 @@ logger = logging.getLogger(__name__)

MAX_NEW_TOKENS = 384

Q = "Answer the following question, step by step:\n{q}\nA:"
question = "What is the meaning of life?"
Q = 'Answer the following question, step by step:\n{q}\nA:'
question = 'What is the meaning of life?'

def main() -> int:
async def main() -> int:
  parser = argparse.ArgumentParser()
  parser.add_argument("question", default=question)
  parser.add_argument('question', default=question)

  if openllm.utils.in_notebook():
    args = parser.parse_args(args=[question])
  else:
    args = parser.parse_args()

  model = openllm.AutoLLM.for_model("opt", model_id="facebook/opt-2.7b", ensure_available=True)
  llm = openllm.LLM[t.Any, t.Any]('facebook/opt-2.7b')
  prompt = Q.format(q=args.question)

  logger.info("-" * 50, "Running with 'generate()'", "-" * 50)
  res = model.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
  logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))
  logger.info('-' * 50, "Running with 'generate()'", '-' * 50)
  res = await llm.generate(prompt)
  logger.info('=' * 10, 'Response:', res)

  logger.info("-" * 50, "Running with 'generate()' with per-requests argument", "-" * 50)
  res = model.generate(prompt, num_return_sequences=3)
  logger.info("=" * 10, "Response:", model.postprocess_generate(prompt, res))

  logger.info("-" * 50, "Using Runner abstraction with runner.generate.run()", "-" * 50)
  r = openllm.Runner("opt", model_id="facebook/opt-350m", init_local=True)
  res = r.generate.run(prompt)
  logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))

  logger.info("-" * 50, "Using Runner abstraction with runner()", "-" * 50)
  res = r(prompt)
  logger.info("=" * 10, "Response:", r.llm.postprocess_generate(prompt, res))
  logger.info('-' * 50, "Running with 'generate()' with per-requests argument", '-' * 50)
  res = await llm.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
  logger.info('=' * 10, 'Response:', res)

  return 0

def _mp_fn(index: t.Any): # noqa # type: ignore
def _mp_fn(index: t.Any): # type: ignore
  # For xla_spawn (TPUs)
  main()
  asyncio.run(main())

if openllm.utils.in_notebook():
  main()
  await main()
else:
  raise SystemExit(main())
  raise SystemExit(asyncio.run(main()))

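For reference, a condensed sketch of the unified interface this example now targets, based only on the calls visible in the diff above (prompt template and model ids reused from the example; error handling omitted):

import asyncio
import typing as t

import openllm

async def ask(question: str) -> t.Any:
  llm = openllm.LLM[t.Any, t.Any]('facebook/opt-2.7b')
  # Per-request overrides are passed straight to generate(), as in the updated example.
  return await llm.generate(f'Answer the following question, step by step:\n{question}\nA:', max_new_tokens=384)

if __name__ == '__main__':
  print(asyncio.run(ask('What is the meaning of life?')))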
@@ -111,15 +111,7 @@ def prepare_for_int4_training(model_id: str,
                              ) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
  from peft.tuners.lora import LoraLayer

  llm = openllm.AutoLLM.for_model("llama",
                                  model_id=model_id,
                                  model_version=model_version,
                                  ensure_available=True,
                                  quantize="int4",
                                  bnb_4bit_compute_dtype=torch.bfloat16,
                                  use_cache=not gradient_checkpointing,
                                  device_map="auto",
                                  )
  llm = openllm.LLM(model_id, revision=model_version, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto")
  print("Model summary:", llm.model)

  # get lora target modules
@@ -185,8 +177,7 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
  trainer = transformers.Trainer(model=model,
                                 args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
                                 train_dataset=datasets,
                                 data_collator=transformers.default_data_collator,
                                 )
                                 data_collator=transformers.default_data_collator)

  trainer.train()

@@ -30,8 +30,7 @@ def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, da
  return transformers.Trainer(model=model,
                              train_dataset=dataset_dict["train"],
                              args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
                              data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
                              )
                              data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False))

@dataclasses.dataclass
class TrainingArguments:
@@ -56,12 +55,9 @@ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
else:
  model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())

model, tokenizer = openllm.AutoLLM.for_model("opt", model_id=model_args.model_id, quantize="int8", ensure_available=True).prepare_for_training(adapter_type="lora",
                                                                                                 r=16,
                                                                                                 lora_alpha=32,
                                                                                                 target_modules=["q_proj", "v_proj"],
                                                                                                 lora_dropout=0.05,
                                                                                                 bias="none")
llm = openllm.LLM(model_args.model_id, quantize="int8")
llm.save_pretrained()
model, tokenizer = llm.prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")

# ft on english_quotes
data = load_dataset("Abirate/english_quotes")

openllm-python/src/openllm/protocol/hf.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from __future__ import annotations
import typing as t

import attr

@attr.define
class AgentRequest:
  inputs: str
  parameters: t.Dict[str, t.Any]

@attr.define
class AgentResponse:
  generated_text: str

@attr.define
class AgentErrorResponse:
  error_code: int
  message: str
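A small usage sketch for the schema above; attr.asdict is standard attrs API, and the payload values are hypothetical:

import attr

req = AgentRequest(inputs='Generate a haiku about servers', parameters={'max_new_tokens': 64})
print(attr.asdict(req))  # {'inputs': 'Generate a haiku about servers', 'parameters': {'max_new_tokens': 64}}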
@@ -6,7 +6,15 @@ import attr
|
||||
|
||||
import openllm_core
|
||||
|
||||
from openllm import _conversation
|
||||
from openllm_core.utils import converter
|
||||
|
||||
@attr.define
|
||||
class ErrorResponse:
|
||||
message: str
|
||||
type: str
|
||||
object: str = 'error'
|
||||
param: t.Optional[str] = None
|
||||
code: t.Optional[str] = None
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
@@ -15,7 +23,7 @@ class CompletionRequest:
|
||||
suffix: t.Optional[str] = attr.field(default=None)
|
||||
max_tokens: t.Optional[int] = attr.field(default=16)
|
||||
temperature: t.Optional[float] = attr.field(default=1.0)
|
||||
top_p: t.Optional[float] = attr.field(default=1)
|
||||
top_p: t.Optional[float] = attr.field(default=1.0)
|
||||
n: t.Optional[int] = attr.field(default=1)
|
||||
stream: t.Optional[bool] = attr.field(default=False)
|
||||
logprobs: t.Optional[int] = attr.field(default=None)
|
||||
@@ -23,9 +31,11 @@ class CompletionRequest:
|
||||
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
|
||||
presence_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
frequency_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
|
||||
user: t.Optional[str] = attr.field(default=None)
|
||||
# supported by vLLM and us
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionRequest:
|
||||
@@ -33,16 +43,19 @@ class ChatCompletionRequest:
|
||||
model: str = attr.field(default=None)
|
||||
functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
|
||||
function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
|
||||
temperature: t.Optional[float] = attr.field(default=1.0)
|
||||
top_p: t.Optional[float] = attr.field(default=1)
|
||||
n: t.Optional[int] = attr.field(default=1)
|
||||
temperature: t.Optional[float] = attr.field(default=None)
|
||||
top_p: t.Optional[float] = attr.field(default=None)
|
||||
n: t.Optional[int] = attr.field(default=None)
|
||||
stream: t.Optional[bool] = attr.field(default=False)
|
||||
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
|
||||
max_tokens: t.Optional[int] = attr.field(default=None)
|
||||
presence_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
frequency_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
presence_penalty: t.Optional[float] = attr.field(default=None)
frequency_penalty: t.Optional[float] = attr.field(default=None)
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
user: t.Optional[str] = attr.field(default=None)
# supported by vLLM and us
top_k: t.Optional[int] = attr.field(default=None)
best_of: t.Optional[int] = attr.field(default=1)

@attr.define
class LogProbs:
@@ -52,80 +65,90 @@ class LogProbs:
top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))

@attr.define
class CompletionTextChoice:
text: str
index: int
logprobs: LogProbs = attr.field(default=attr.Factory(lambda: LogProbs()))
finish_reason: str = attr.field(default=None)

@attr.define
class Usage:
class UsageInfo:
prompt_tokens: int = attr.field(default=0)
completion_tokens: int = attr.field(default=0)
total_tokens: int = attr.field(default=0)

@attr.define
class CompletionResponse:
choices: t.List[CompletionTextChoice]
class CompletionResponseChoice:
index: int
text: str
logprobs: t.Optional[LogProbs] = None
finish_reason: t.Optional[str] = None

@attr.define
class CompletionResponseStreamChoice:
index: int
text: str
logprobs: t.Optional[LogProbs] = None
finish_reason: t.Optional[str] = None

@attr.define
class CompletionStreamResponse:
model: str
choices: t.List[CompletionResponseStreamChoice]
object: str = 'text_completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))

@attr.define
class CompletionResponseStream:
choices: t.List[CompletionTextChoice]
class CompletionResponse:
choices: t.List[CompletionResponseChoice]
model: str
usage: UsageInfo
object: str = 'text_completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))

LiteralRole = t.Literal['system', 'user', 'assistant']

class Message(t.TypedDict):
role: LiteralRole
content: str

@attr.define
class Delta:
role: t.Optional[LiteralRole] = None
content: t.Optional[str] = None

@attr.define
class ChatMessage:
role: LiteralRole
content: str

@attr.define
class ChatCompletionChoice:
index: int
message: Message
finish_reason: str = attr.field(default=None)
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})

@attr.define
class ChatCompletionStreamChoice:
class ChatCompletionResponseStreamChoice:
index: int
delta: Message
finish_reason: str = attr.field(default=None)
delta: Delta
finish_reason: t.Optional[str] = attr.field(default=None)

@attr.define
class ChatCompletionResponseChoice:
index: int
message: ChatMessage
finish_reason: t.Optional[str] = attr.field(default=None)

@attr.define
class ChatCompletionResponse:
choices: t.List[ChatCompletionChoice]
choices: t.List[ChatCompletionResponseChoice]
model: str
object: str = 'chat.completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
usage: UsageInfo = attr.field(default=attr.Factory(lambda: UsageInfo()))

@attr.define
class ChatCompletionResponseStream:
choices: t.List[ChatCompletionStreamChoice]
class ChatCompletionStreamResponse:
choices: t.List[ChatCompletionResponseStreamChoice]
model: str
object: str = 'chat.completion.chunk'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))

@attr.define
class ModelCard:
id: str
object: str = 'model'
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
owned_by: str = 'na'

@attr.define
@@ -133,10 +156,14 @@ class ModelList:
object: str = 'list'
data: t.List[ModelCard] = attr.field(factory=list)

def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_core.LLMConfig) -> str:
conv_template = _conversation.get_conv_template(model, llm_config)
for message in messages:
if message['role'] == 'system': conv_template.set_system_message(message['content'])
else: conv_template.append_message(message['role'], message['content'])
conv_template.append_message('assistant', '')
return conv_template.get_prompt()
async def get_conversation_prompt(request: ChatCompletionRequest, llm_config: openllm_core.LLMConfig) -> str:
conv = llm_config.get_conversation_template()
for message in request.messages:
msg_role = message['role']
if msg_role == 'system': conv.set_system_message(message['content'])
elif msg_role == 'user': conv.append_message(conv.roles[0], message['content'])
elif msg_role == 'assistant': conv.append_message(conv.roles[1], message['content'])
else: raise ValueError(f'Unknown role: {msg_role}')
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], '')
return conv.get_prompt()
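The hunk above renames the OpenAI-compatible response schemas (Usage becomes UsageInfo, and the choice classes gain explicit Response/Stream variants). A minimal usage sketch follows; the import path is an assumption, since this hunk does not show the module header, but every class and field name is taken from the new code above.

# Hypothetical import path -- only the class and field names below come from the diff.
from openllm_core.protocol.openai import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, UsageInfo

response = ChatCompletionResponse(
  model='facebook/opt-1.3b',
  choices=[ChatCompletionResponseChoice(index=0, message=ChatMessage(role='assistant', content='Hello!'), finish_reason='stop')],
  usage=UsageInfo(prompt_tokens=5, completion_tokens=2, total_tokens=7),
)
# 'object', 'id' and 'created' fall back to the factory defaults defined above, and the
# registered unstructure hook serialises ChatMessage into a plain dict.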
@@ -1,6 +1,6 @@
'''Serialisation utilities for OpenLLM.

Currently supports transformers for PyTorch, Tensorflow and Flax.
Currently supports transformers for PyTorch and vLLM.

Currently, the GGML format is a work in progress.
'''
@@ -19,11 +19,15 @@ from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import T

if t.TYPE_CHECKING:
import transformers as _transformers

import bentoml

from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers
else:
_transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')

P = ParamSpec('P')

@@ -33,12 +37,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
By default, it will check whether the bentomodel is in the store.
If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
'''
from .transformers._helpers import infer_tokenizers_from_llm
from .transformers._helpers import process_config

config, *_ = process_config(llm._bentomodel.path, llm.trust_remote_code)
config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)

bentomodel_fs = fs.open_fs(llm._bentomodel.path)
bentomodel_fs = fs.open_fs(llm.bentomodel.path)
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
try:
@@ -47,7 +50,7 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
"For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
else:
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
tokenizer = _transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)

if tokenizer.pad_token_id is None:
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
@@ -66,7 +69,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.

> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "tf", "flax", "vllm")'
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'

> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
"""
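For context, the docstring above describes _make_dispatch_function as a generic dispatcher that routes serialisation calls by backend. A rough, illustrative sketch of that pattern (not the exact function body from this file) could look like this:

import importlib
import typing as t

def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
  def caller(llm: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    # 'pt' and 'vllm' go through the transformers submodule, 'ggml' through its own.
    submodule = 'ggml' if llm.__llm_backend__ == 'ggml' else 'transformers'
    impl = importlib.import_module(f'.{submodule}', 'openllm.serialisation')
    return getattr(impl, fn)(llm, *args, **kwargs)
  return caller

# e.g. import_model = _make_dispatch_function('import_model')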
@@ -1,9 +1,7 @@
from __future__ import annotations

FRAMEWORK_TO_AUTOCLASS_MAPPING = {
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
}
FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
CONFIG_FILE_NAME = 'config.json'
# the below is similar to peft.utils.other.CONFIG_NAME
PEFT_CONFIG_NAME = 'adapter_config.json'
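With 'tf' and 'flax' removed, the flattened mapping above only distinguishes the causal-LM from the seq2seq auto-class per backend. A hedged sketch of how such a mapping can be resolved against transformers (the helper name here is hypothetical; the real lookup lives in the serialisation helpers):

import transformers

FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}

def resolve_autoclass(backend: str, model_type: str) -> type:
  causal, seq2seq = FRAMEWORK_TO_AUTOCLASS_MAPPING[backend]
  # seq2seq models (e.g. flan-t5) use the Seq2SeqLM auto class, everything else the CausalLM one.
  return getattr(transformers, seq2seq if model_type == 'seq2seq_lm' else causal)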
@@ -4,10 +4,10 @@ import importlib
import logging
import typing as t

import attr
import orjson

from huggingface_hub import snapshot_download
from packaging.version import Version
from simple_di import Provide
from simple_di import inject

@@ -16,13 +16,13 @@ import openllm

from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
from bentoml._internal.models.model import ModelSignature
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T

from ._helpers import check_unintialised_params
from ._helpers import get_hash
from ._helpers import infer_autoclass_from_llm
from ._helpers import infer_tokenizers_from_llm
from ._helpers import make_model_signatures
from ._helpers import process_config
from .weights import HfIgnore

@@ -32,16 +32,30 @@ if t.TYPE_CHECKING:
import auto_gptq as autogptq
import torch
import torch.nn
import transformers

from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
else:
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')

logger = logging.getLogger(__name__)

__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__

def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
# NOTE: The following won't hit for local models, since we generate a correct version based on the local path hash. It only hits when the model comes from the HF Hub.
if not llm._local:
try:
if _revision is None: _revision = get_hash(config)
except ValueError:
pass
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
else: _revision = llm._tag.version
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version

@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
@@ -49,7 +63,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,

All kwargs are first passed to `transformers.AutoConfig.from_pretrained`,
which returns all of the unused kwargs.
The unused kwargs are then passed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
The unused kwargs are then passed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
For all tokenizer kwargs, make sure to prefix them with `_tokenizer_` to avoid confusion.

Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
@@ -57,20 +71,22 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
Refer to Transformers documentation for more information about kwargs.

Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
"""
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm._quantize
quantize = llm._quantise
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
if quantize: metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)

signatures: DictStrAny = {}

@@ -79,26 +95,24 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False):
attrs.pop('quantization_config')
if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = llm.__llm_backend__
signatures.update(make_model_signatures(llm))
signatures.update({
k: ModelSignature(batchable=False)
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
})

tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
bentomodel = bentoml.Model.create(llm.tag,
module='openllm.serialisation.transformers',
api_version='v2',
api_version='v2.1.0',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
@@ -108,13 +122,12 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local:
# possible local path
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
@@ -133,6 +146,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
del model
return bentomodel

def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
@@ -145,31 +159,35 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
'''
try:
model = bentoml.models.get(llm.tag)
if Version(model.info.api_version) < Version('v2'):
raise openllm.exceptions.OpenLLMException('Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
if model.info.labels['backend'] != llm.__llm_backend__:
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}.")
backend = model.info.labels['backend']
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
_patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
return model
except Exception as err:
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err

def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
_patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
auto_class = infer_autoclass_from_llm(llm, config)
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config

if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")

model = auto_class.from_pretrained(llm._bentomodel.path, device_map='auto', **hub_attrs, **attrs)
try:
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
except Exception as err:
logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
# XXX: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
# Seems like the logic below requires to add support for safetensors on accelerate
#
# from accelerate import init_empty_weights
# from optimum.gptq import load_quantized_model
# # disable exllama if gptq is loaded on CPU
@@ -179,6 +197,6 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
# empty.tie_weights()
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
else:
model = auto_class.from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ in {'pt', 'vllm'}: check_unintialised_params(model)
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
return t.cast('M', model)

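Taken together, import_model/get/load_model give the save-then-load flow for the transformers backend. A hedged end-to-end sketch, assuming the unified openllm.LLM constructor introduced later in this diff and that these helpers are reachable as shown (exact call sites may differ):

import typing as t
import openllm
from openllm.serialisation import transformers as serde

llm = openllm.LLM[t.Any, t.Any]('facebook/opt-1.3b', backend='pt')
bentomodel = serde.get(llm, auto_import=True)  # imports into the BentoML model store on first use
model = serde.load_model(llm)                  # honours GPTQ metadata and falls back when flash-attention 2 is unavailable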
@@ -5,7 +5,6 @@ import typing as t
import openllm
import openllm_core

from bentoml._internal.models.model import ModelSignature
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS

@@ -15,13 +14,17 @@ if t.TYPE_CHECKING:

from transformers.models.auto.auto_factory import _BaseAutoModelClass

from bentoml._internal.models.model import ModelSignaturesType
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')

def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash

def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
'''A helper function that correctly parses the config and attributes for transformers.PretrainedConfig.

@@ -42,17 +45,11 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs

def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
if __cls is None:
raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
return __cls

def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
if llm.config['trust_remote_code']:
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
# in case this model doesn't use the correct auto class for its model type, for example chatglm,
# which uses AutoModel instead of AutoModelForCausalLM, we fall back to AutoModel
if autoclass not in config.auto_map: autoclass = 'AutoModel'
@@ -67,15 +64,3 @@ def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')

# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ('__call__',)
default_config = ModelSignature(batchable=False)
if llm.__llm_backend__ in {'pt', 'vllm'}:
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
elif llm.__llm_backend__ == 'tf':
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
else:
infer_fn += ('generate',)
return {k: default_config for k in infer_fn}

@@ -24,19 +24,11 @@ class HfIgnore:

@classmethod
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_backend__ == 'vllm':
if llm.__llm_backend__ in {'vllm', 'pt'}:
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
elif llm.__llm_backend__ == 'flax':
base = [cls.tf, cls.pt, cls.safetensors, cls.gguf] # as of current, safetensors is not supported with flax
elif llm.__llm_backend__ == 'pt':
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
else:
raise ValueError('Unknown backend (should never happen at all.)')
# filter out these files, since we probably don't need them for now.

@@ -42,15 +42,15 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N

@contextlib.contextmanager
def prepare(model: str,
model_id: str | None = None,
implementation: LiteralBackend = 'pt',
model_id: str,
backend: LiteralBackend = 'pt',
deployment_mode: t.Literal['container', 'local'] = 'local',
clean_context: contextlib.ExitStack | None = None,
cleanup: bool = True) -> t.Iterator[str]:
if clean_context is None:
clean_context = contextlib.ExitStack()
cleanup = True
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
llm = openllm.LLM[t.Any, t.Any](model_id, backend=backend)
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{llm.tag.version}')
if not bentoml.list(bento_tag):
bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))

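The prepare() change above is the crux of this refactor: instead of resolving a per-backend auto class, callers construct a single openllm.LLM with an explicit backend. An illustrative before/after, using the two call sites shown in this hunk:

import typing as t
import openllm

# before: backend-specific auto classes
# llm = openllm.infer_auto_class('pt').for_model('opt', model_id='facebook/opt-1.3b', ensure_available=True)

# after: one unified constructor
llm = openllm.LLM[t.Any, t.Any]('facebook/opt-1.3b', backend='pt')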
@@ -8,28 +8,13 @@ import typing as t

import openllm_core

from . import dummy_flax_objects as dummy_flax_objects
from . import dummy_pt_objects as dummy_pt_objects
from . import dummy_tf_objects as dummy_tf_objects
from . import dummy_vllm_objects as dummy_vllm_objects

if t.TYPE_CHECKING:
import openllm

from openllm_core._typing_compat import LiteralBackend

def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
return {'backend': llm.__llm_backend__, 'framework': 'openllm', 'model_name': llm.config['model_name'], 'architecture': llm.config['architecture'], 'serialisation': llm._serialisation}

def infer_auto_class(backend: LiteralBackend) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
import openllm
if backend == 'tf': return openllm.AutoTFLLM
elif backend == 'flax': return openllm.AutoFlaxLLM
elif backend == 'pt': return openllm.AutoLLM
elif backend == 'vllm': return openllm.AutoVLLM
else: raise RuntimeError(f"Unknown backend: {backend} (supported: 'pt', 'flax', 'tf', 'vllm')")

__all__ = ['generate_labels', 'infer_auto_class', 'dummy_flax_objects', 'dummy_pt_objects', 'dummy_tf_objects', 'dummy_vllm_objects']
__all__ = ['generate_labels']

def __dir__() -> t.Sequence[str]:
return sorted(__all__)

@@ -1,16 +0,0 @@
# This file is generated by tools/update-dummy.py. DO NOT EDIT MANUALLY!
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
class FlaxFlanT5(metaclass=_DummyMetaclass):
_backends=["flax"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])
class FlaxOPT(metaclass=_DummyMetaclass):
_backends=["flax"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])
class AutoFlaxLLM(metaclass=_DummyMetaclass):
_backends=["flax"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])
MODEL_FLAX_MAPPING_NAMES:_t.Any=None
__all__:list[str]=["MODEL_FLAX_MAPPING_NAMES","AutoFlaxLLM","FlaxFlanT5","FlaxOPT"]
43 openllm-python/src/openllm/utils/dummy_pt_objects.py (generated)
@@ -1,43 +0,0 @@
# This file is generated by tools/update-dummy.py. DO NOT EDIT MANUALLY!
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
class ChatGLM(metaclass=_DummyMetaclass):
_backends=["torch","cpm_kernels","sentencepiece"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","cpm_kernels","sentencepiece"])
class DollyV2(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
class Falcon(metaclass=_DummyMetaclass):
_backends=["torch","einops","xformers"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","einops","xformers"])
class FlanT5(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
class GPTNeoX(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
class Llama(metaclass=_DummyMetaclass):
_backends=["torch","fairscale","sentencepiece","scipy"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","fairscale","sentencepiece","scipy"])
class MPT(metaclass=_DummyMetaclass):
_backends=["torch","triton","einops"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","triton","einops"])
class OPT(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
class StableLM(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
class StarCoder(metaclass=_DummyMetaclass):
_backends=["torch","bitsandbytes"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","bitsandbytes"])
class Baichuan(metaclass=_DummyMetaclass):
_backends=["torch","cpm_kernels","sentencepiece"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","cpm_kernels","sentencepiece"])
class AutoLLM(metaclass=_DummyMetaclass):
_backends=["torch"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch"])
MODEL_MAPPING_NAMES:_t.Any=None
__all__:list[str]=["MODEL_MAPPING_NAMES","AutoLLM","ChatGLM","DollyV2","Falcon","FlanT5","GPTNeoX","Llama","MPT","OPT","StableLM","StarCoder","Baichuan"]
16 openllm-python/src/openllm/utils/dummy_tf_objects.py (generated)
@@ -1,16 +0,0 @@
# This file is generated by tools/update-dummy.py. DO NOT EDIT MANUALLY!
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
class TFFlanT5(metaclass=_DummyMetaclass):
_backends=["tensorflow"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])
class TFOPT(metaclass=_DummyMetaclass):
_backends=["tensorflow"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])
class AutoTFLLM(metaclass=_DummyMetaclass):
_backends=["tensorflow"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])
MODEL_TF_MAPPING_NAMES:_t.Any=None
__all__:list[str]=["MODEL_TF_MAPPING_NAMES","AutoTFLLM","TFFlanT5","TFOPT"]
@@ -1,37 +0,0 @@
# This file is generated by tools/update-dummy.py. DO NOT EDIT MANUALLY!
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
class VLLMBaichuan(metaclass=_DummyMetaclass):
_backends=["vllm","cpm_kernels","sentencepiece"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","cpm_kernels","sentencepiece"])
class VLLMDollyV2(metaclass=_DummyMetaclass):
_backends=["vllm"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
class VLLMFalcon(metaclass=_DummyMetaclass):
_backends=["vllm","einops","xformers"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","einops","xformers"])
class VLLMGPTNeoX(metaclass=_DummyMetaclass):
_backends=["vllm"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
class VLLMMPT(metaclass=_DummyMetaclass):
_backends=["vllm","triton","einops"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","triton","einops"])
class VLLMOPT(metaclass=_DummyMetaclass):
_backends=["vllm"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
class VLLMStableLM(metaclass=_DummyMetaclass):
_backends=["vllm"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
class VLLMStarCoder(metaclass=_DummyMetaclass):
_backends=["vllm","bitsandbytes"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","bitsandbytes"])
class VLLMLlama(metaclass=_DummyMetaclass):
_backends=["vllm","fairscale","sentencepiece","scipy"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","fairscale","sentencepiece","scipy"])
class AutoVLLM(metaclass=_DummyMetaclass):
_backends=["vllm"]
def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm"])
MODEL_VLLM_MAPPING_NAMES:_t.Any=None
__all__:list[str]=["MODEL_VLLM_MAPPING_NAMES","AutoVLLM","VLLMBaichuan","VLLMDollyV2","VLLMFalcon","VLLMGPTNeoX","VLLMMPT","VLLMOPT","VLLMStableLM","VLLMStarCoder","VLLMLlama"]