mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-04 15:16:03 -05:00
refactor: delete unused code (#716)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -3,14 +3,12 @@ import os as _os
|
||||
import pathlib as _pathlib
|
||||
import warnings as _warnings
|
||||
|
||||
import openllm_cli as _cli
|
||||
from openllm_cli import _sdk
|
||||
|
||||
from . import utils as utils
|
||||
|
||||
if utils.DEBUG:
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
_logging.basicConfig(level=_logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
@@ -47,18 +45,9 @@ __lazy = utils.LazyModule(
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM'],
|
||||
'_generation': [
|
||||
'StopSequenceCriteria',
|
||||
'StopOnTokens',
|
||||
'prepare_logits_processor',
|
||||
'get_context_length',
|
||||
'is_sentence_complete',
|
||||
'is_partial_stop',
|
||||
],
|
||||
},
|
||||
extra_objects={
|
||||
'COMPILED': COMPILED,
|
||||
'cli': _cli,
|
||||
'start': _sdk.start,
|
||||
'start_grpc': _sdk.start_grpc,
|
||||
'build': _sdk.build,
|
||||
|
||||
@@ -16,7 +16,6 @@ from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING,CONFIG_MAPPING_N
|
||||
# update-config-stubs.py: import stubs stop
|
||||
# fmt: on
|
||||
|
||||
import openllm_cli as _cli
|
||||
from openllm_cli._sdk import (
|
||||
build as build,
|
||||
import_model as import_model,
|
||||
@@ -44,14 +43,6 @@ from . import (
|
||||
utils as utils,
|
||||
)
|
||||
from ._deprecated import Runner as Runner
|
||||
from ._generation import (
|
||||
StopOnTokens as StopOnTokens,
|
||||
StopSequenceCriteria as StopSequenceCriteria,
|
||||
prepare_logits_processor as prepare_logits_processor,
|
||||
is_partial_stop as is_partial_stop,
|
||||
is_sentence_complete as is_sentence_complete,
|
||||
get_context_length as get_context_length,
|
||||
)
|
||||
from ._llm import LLM as LLM
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
@@ -60,5 +51,4 @@ from .entrypoints import mount_entrypoints as mount_entrypoints
|
||||
from .protocol import openai as openai
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
|
||||
cli = _cli
|
||||
COMPILED: bool = ...
|
||||
|
||||
@@ -1,13 +1,2 @@
|
||||
"""CLI entrypoint for OpenLLM.
|
||||
|
||||
Usage:
|
||||
openllm --help
|
||||
|
||||
To start any OpenLLM model:
|
||||
openllm start <model_name> --options ...
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
from openllm_cli.entrypoint import cli
|
||||
|
||||
cli()
|
||||
# fmt: off
|
||||
if __name__ == '__main__':from openllm_cli.entrypoint import cli;cli() # noqa
|
||||
|
||||
@@ -6,7 +6,10 @@ import warnings
|
||||
|
||||
import openllm
|
||||
from openllm_core._typing_compat import LiteralBackend, ParamSpec
|
||||
from openllm_core.utils import first_not_none, is_vllm_available
|
||||
from openllm_core.utils import first_not_none, getenv, is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ._runners import Runner as _Runner
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@@ -20,7 +23,7 @@ def Runner(
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> openllm.LLMRunner[t.Any, t.Any]:
|
||||
) -> _Runner[t.Any, t.Any]:
|
||||
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
|
||||
|
||||
> [!WARNING]
|
||||
@@ -73,9 +76,9 @@ def Runner(
|
||||
attrs.update(
|
||||
{
|
||||
'model_id': model_id,
|
||||
'quantize': os.getenv('OPENLLM_QUANTIZE', attrs.get('quantize', None)),
|
||||
'serialisation': first_not_none(
|
||||
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
|
||||
'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)),
|
||||
'serialisation': getenv(
|
||||
'serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,23 +1,6 @@
|
||||
import transformers
|
||||
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences, tokenizer):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
return any(
|
||||
self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
|
||||
def prepare_logits_processor(config):
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
|
||||
@@ -1,27 +1,7 @@
|
||||
from typing import Any, List, Union
|
||||
|
||||
from torch import FloatTensor, LongTensor
|
||||
from transformers import (
|
||||
LogitsProcessorList,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers import LogitsProcessorList, PretrainedConfig
|
||||
|
||||
from openllm_core import LLMConfig
|
||||
|
||||
Tokenizer = Union[PreTrainedTokenizerBase, PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
|
||||
class StopSequenceCriteria:
|
||||
stop_sequences: List[str]
|
||||
tokenizer: Tokenizer
|
||||
def __init__(self, stop_sequences: Union[str, List[str]], tokenizer: Tokenizer) -> None: ...
|
||||
def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs: Any) -> bool: ...
|
||||
|
||||
class StopOnTokens:
|
||||
def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs: Any) -> bool: ...
|
||||
|
||||
def prepare_logits_processor(config: LLMConfig) -> LogitsProcessorList: ...
|
||||
def get_context_length(config: PretrainedConfig) -> int: ...
|
||||
def is_sentence_complete(output: str) -> bool: ...
|
||||
|
||||
@@ -71,9 +71,7 @@ def normalise_model_name(name: str) -> str:
|
||||
|
||||
def _resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
if not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
raise RuntimeError("Requires 'peft' to be installed. Do 'pip install \"openllm[fine-tune]\"'")
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
resolved: AdapterMap = {}
|
||||
@@ -285,8 +283,6 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
if env is not None:return str(env).upper() in ENV_VARS_TRUE_VALUES
|
||||
return self.__llm_trust_remote_code__
|
||||
@property
|
||||
def runner_name(self):return f"llm-{self.config['start_name']}-runner"
|
||||
@property
|
||||
def model_id(self):return self._model_id
|
||||
@property
|
||||
def revision(self):return self._revision
|
||||
|
||||
@@ -97,8 +97,6 @@ class LLM(Generic[M, T]):
|
||||
@property
|
||||
def trust_remote_code(self) -> bool: ...
|
||||
@property
|
||||
def runner_name(self) -> str: ...
|
||||
@property
|
||||
def model_id(self) -> str: ...
|
||||
@property
|
||||
def revision(self) -> str: ...
|
||||
|
||||
@@ -9,7 +9,6 @@ import torch
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available
|
||||
|
||||
__all__ = ['runner']
|
||||
@@ -28,12 +27,10 @@ def registry(cls=None, *, alias=None):
|
||||
|
||||
|
||||
def runner(llm: openllm.LLM):
|
||||
from ._strategies import CascadingResourceStrategy
|
||||
|
||||
try:
|
||||
models = [llm.bentomodel]
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
raise RuntimeError(f'Failed to locate {llm.bentomodel}:{err}') from err
|
||||
assert llm.bentomodel
|
||||
except (bentoml.exceptions.NotFound, AssertionError) as err:
|
||||
raise RuntimeError(f'Failed to locate {llm.bentomodel}: {err}') from err
|
||||
|
||||
return types.new_class(
|
||||
llm.config.__class__.__name__[:-6] + 'Runner',
|
||||
@@ -73,9 +70,9 @@ def runner(llm: openllm.LLM):
|
||||
),
|
||||
)(
|
||||
_registry[llm.__llm_backend__],
|
||||
name=llm.runner_name,
|
||||
models=models,
|
||||
scheduling_strategy=CascadingResourceStrategy,
|
||||
name=f"llm-{llm.config['start_name']}-runner",
|
||||
models=[llm.bentomodel],
|
||||
scheduling_strategy=openllm.CascadingResourceStrategy,
|
||||
runnable_init_params={'llm': llm},
|
||||
)
|
||||
|
||||
@@ -87,7 +84,7 @@ class CTranslateRunnable(bentoml.Runnable):
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available():
|
||||
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
|
||||
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
@@ -137,7 +134,7 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_vllm_available():
|
||||
raise OpenLLMException('vLLM is not installed. Please install it via `pip install "openllm[vllm]"`.')
|
||||
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
import vllm
|
||||
|
||||
self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
|
||||
@@ -162,7 +159,9 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to initialise vLLMEngine due to the following error:\n{err}'
|
||||
) from err
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
|
||||
@@ -14,22 +14,13 @@ import psutil
|
||||
import bentoml
|
||||
from bentoml._internal.resource import get_resource, system_resources
|
||||
from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.utils import DEBUG, ReprMixin
|
||||
|
||||
|
||||
class DynResource(t.Protocol):
|
||||
resource_id: t.ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def from_system(cls) -> t.Sequence[t.Any]: ...
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
# Return -1 or positive integer sequence string starts with.
|
||||
if not s:
|
||||
return -1
|
||||
idx = 0
|
||||
@@ -55,21 +46,6 @@ def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
return rcs
|
||||
|
||||
|
||||
_STACK_LEVEL = 3
|
||||
|
||||
|
||||
@overload # variant: default callback
|
||||
def _parse_visible_devices() -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: specify None, and respect_env
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: default var is something other than None
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]: ...
|
||||
|
||||
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
@@ -101,146 +77,136 @@ def _parse_visible_devices(default_var: str | None = None, respect_env: bool = T
|
||||
return [str(i) for i in rc]
|
||||
|
||||
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX:
|
||||
if DEBUG:
|
||||
logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append('/opt/rocm/libexec/rocm_smi')
|
||||
try:
|
||||
from ctypes import byref, c_uint32
|
||||
|
||||
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
|
||||
from rsmiBindings import rocmsmi, rsmi_status_t
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return []
|
||||
finally:
|
||||
sys.path.remove('/opt/rocm/libexec/rocm_smi')
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
except (ImportError, RuntimeError, AttributeError):
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0):
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec:
|
||||
return []
|
||||
if spec.isdigit():
|
||||
spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead."
|
||||
)
|
||||
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer
|
||||
|
||||
try:
|
||||
nvml_h = CDLL('libnvidia-ml.so.1')
|
||||
except Exception:
|
||||
warnings.warn('Failed to find nvidia binding', stacklevel=_STACK_LEVEL)
|
||||
warnings.warn('Failed to find nvidia binding', stacklevel=3)
|
||||
return None
|
||||
|
||||
rc = nvml_h.nvmlInit()
|
||||
if rc != 0:
|
||||
warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL)
|
||||
warnings.warn("Can't initialize NVML", stacklevel=3)
|
||||
return None
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
if rc != 0:
|
||||
warnings.warn('Failed to get available device from system.', stacklevel=_STACK_LEVEL)
|
||||
warnings.warn('Failed to get available device from system.', stacklevel=3)
|
||||
return None
|
||||
uuids: list[str] = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn(f'Failed to get device handle for {idx}', stacklevel=_STACK_LEVEL)
|
||||
warnings.warn(f'Failed to get device handle for {idx}', stacklevel=3)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=_STACK_LEVEL)
|
||||
warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=3)
|
||||
return None
|
||||
uuids.append(buf.raw.decode('ascii').strip('\0'))
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError(
|
||||
"AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'"
|
||||
)
|
||||
if not all(isinstance(i, str) for i in val):
|
||||
raise ValueError('Input list should be all string type.')
|
||||
class _ResourceMixin:
|
||||
@staticmethod
|
||||
def from_system(cls) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX:
|
||||
if DEBUG:
|
||||
logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append('/opt/rocm/libexec/rocm_smi')
|
||||
try:
|
||||
from ctypes import byref, c_uint32
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
# refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py
|
||||
from rsmiBindings import rocmsmi, rsmi_status_t
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith(('GPU-', 'MIG-')):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None:
|
||||
raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids:
|
||||
raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return []
|
||||
finally:
|
||||
sys.path.remove('/opt/rocm/libexec/rocm_smi')
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
except (ImportError, RuntimeError, AttributeError):
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@staticmethod
|
||||
def from_spec(cls, spec) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0):
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec:
|
||||
return []
|
||||
if spec.isdigit():
|
||||
spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate(cls, val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError(
|
||||
"AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'"
|
||||
)
|
||||
if not all(isinstance(i, str) for i in val):
|
||||
raise ValueError('Input list should be all string type.')
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith(('GPU-', 'MIG-')):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None:
|
||||
raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids:
|
||||
raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[bentoml.Resource[t.List[str]]]:
|
||||
return types.new_class(
|
||||
name,
|
||||
(bentoml.Resource[t.List[str]], ReprMixin),
|
||||
@@ -248,9 +214,9 @@ def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'from_spec': classmethod(_ResourceMixin.from_spec),
|
||||
'from_system': classmethod(_ResourceMixin.from_system),
|
||||
'validate': classmethod(_ResourceMixin.validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies',
|
||||
@@ -259,15 +225,9 @@ def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[
|
||||
)
|
||||
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal 🤦
|
||||
_TPU_RESOURCE: t.Literal['cloud-tpus.google.com/v2'] = 'cloud-tpus.google.com/v2'
|
||||
_AMD_GPU_RESOURCE: t.Literal['amd.com/gpu'] = 'amd.com/gpu'
|
||||
_NVIDIA_GPU_RESOURCE: t.Literal['nvidia.com/gpu'] = 'nvidia.com/gpu'
|
||||
_CPU_RESOURCE: t.Literal['cpu'] = 'cpu'
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource',
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
'nvidia.com/gpu',
|
||||
"""NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
@@ -275,7 +235,7 @@ NvidiaGpuResource = _make_resource_class(
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource',
|
||||
_AMD_GPU_RESOURCE,
|
||||
'amd.com/gpu',
|
||||
"""AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
# fmt: off
|
||||
import openllm_client as _client
|
||||
|
||||
|
||||
def __dir__():
|
||||
return sorted(dir(_client))
|
||||
|
||||
|
||||
def __getattr__(it):
|
||||
return getattr(_client, it)
|
||||
def __dir__():return sorted(dir(_client))
|
||||
def __getattr__(it):return getattr(_client, it)
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralBackend, LiteralQuantise
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(
|
||||
model: str, model_id: str | None = None, quantize: LiteralQuantise | None = None, cleanup: bool = False
|
||||
) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info('Building BentoML for %s', model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize)
|
||||
yield bento
|
||||
if cleanup:
|
||||
logger.info('Deleting %s', bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(
|
||||
bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any
|
||||
) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento):
|
||||
bento_tag = bento.tag
|
||||
else:
|
||||
bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None:
|
||||
image_tag = str(bento_tag)
|
||||
executable = shutil.which('docker')
|
||||
if not executable:
|
||||
raise RuntimeError('docker executable not found')
|
||||
try:
|
||||
logger.info('Building container for %s', bento_tag)
|
||||
bentoml.container.build(bento_tag, backend='docker', image_tag=(image_tag,), progress='plain', **attrs)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info('Deleting container %s', image_tag)
|
||||
subprocess.check_output([executable, 'rmi', '-f', image_tag])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str,
|
||||
backend: LiteralBackend = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True,
|
||||
) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
llm = openllm.LLM[t.Any, t.Any](model_id, backend=backend)
|
||||
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{llm.tag.version}')
|
||||
if not bentoml.list(bento_tag):
|
||||
bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else:
|
||||
bento = bentoml.get(bento_tag)
|
||||
container_name = f'openllm-{model}-{llm.llm_type}'.replace('-', '_')
|
||||
if deployment_mode == 'container':
|
||||
container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup:
|
||||
clean_context.close()
|
||||
@@ -44,7 +44,6 @@ from openllm_core.utils import (
|
||||
is_transformers_available as is_transformers_available,
|
||||
is_vllm_available as is_vllm_available,
|
||||
lenient_issubclass as lenient_issubclass,
|
||||
reserve_free_port as reserve_free_port,
|
||||
resolve_filepath as resolve_filepath,
|
||||
resolve_user_filepath as resolve_user_filepath,
|
||||
serde as serde,
|
||||
@@ -21,7 +21,7 @@ from openllm_core._typing_compat import (
|
||||
ParamSpec,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core.utils import DEBUG, resolve_user_filepath
|
||||
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
@@ -134,7 +134,7 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
|
||||
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
composed = compose(
|
||||
_OpenLLM_GenericInternalConfig.parse,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
@@ -160,7 +160,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
type=dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
|
||||
@@ -294,7 +294,7 @@ def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--show-available', '--quiet'], standalone_mode=False)
|
||||
return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, start_grpc = codegen.gen_sdk(_start, _serve_grpc=False), codegen.gen_sdk(_start, _serve_grpc=True)
|
||||
|
||||
@@ -1,25 +1,3 @@
|
||||
"""OpenLLM CLI interface.
|
||||
|
||||
This module also contains the SDK to call ``start`` and ``build`` from SDK
|
||||
|
||||
Start any LLM:
|
||||
|
||||
```python
|
||||
openllm.start('mistral', model_id='mistralai/Mistral-7B-v0.1')
|
||||
```
|
||||
|
||||
Build a BentoLLM
|
||||
|
||||
```python
|
||||
bento = openllm.build('mistralai/Mistral-7B-v0.1')
|
||||
```
|
||||
|
||||
Import any LLM into local store
|
||||
```python
|
||||
bentomodel = openllm.import_model('mistralai/Mistral-7B-v0.1')
|
||||
```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
@@ -91,7 +69,6 @@ from openllm_core.utils import (
|
||||
from . import termui
|
||||
from ._factory import (
|
||||
FC,
|
||||
LiteralOutput,
|
||||
_AnyCallable,
|
||||
backend_option,
|
||||
container_registry_option,
|
||||
@@ -1225,7 +1202,11 @@ def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
|
||||
@model_name_argument(required=False)
|
||||
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
|
||||
@click.option(
|
||||
'--include-bentos/--no-include-bentos', is_flag=True, default=False, help='Whether to also include pruning bentos.'
|
||||
'--include-bentos/--no-include-bentos',
|
||||
is_flag=True,
|
||||
hidden=True,
|
||||
default=True,
|
||||
help='Whether to also include pruning bentos.',
|
||||
)
|
||||
@inject
|
||||
@click.pass_context
|
||||
@@ -1233,11 +1214,11 @@ def prune_command(
|
||||
ctx: click.Context,
|
||||
model_name: str | None,
|
||||
yes: bool,
|
||||
include_bentos: bool,
|
||||
model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
**_: t.Any,
|
||||
) -> None:
|
||||
"""Remove all saved models, (and optionally bentos) built with OpenLLM locally.
|
||||
"""Remove all saved models, and bentos built with OpenLLM locally.
|
||||
|
||||
\b
|
||||
If a model type is passed, then only prune models for that given model type.
|
||||
@@ -1252,18 +1233,15 @@ def prune_command(
|
||||
(m, store)
|
||||
for m, store in available
|
||||
if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)
|
||||
] + [
|
||||
(b, bento_store)
|
||||
for b in bentoml.bentos.list()
|
||||
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
|
||||
]
|
||||
if model_name is None:
|
||||
available += [
|
||||
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
|
||||
]
|
||||
if include_bentos:
|
||||
if model_name is not None:
|
||||
available += [
|
||||
(b, bento_store)
|
||||
for b in bentoml.bentos.list()
|
||||
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
|
||||
]
|
||||
else:
|
||||
available += [
|
||||
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
|
||||
]
|
||||
|
||||
for store_item, store in available:
|
||||
if yes:
|
||||
@@ -1316,69 +1294,6 @@ def shared_client_options(f: _AnyCallable | None = None) -> t.Callable[[FC], FC]
|
||||
return compose(*options)(f) if f is not None else compose(*options)
|
||||
|
||||
|
||||
# Hidden CLI command. It is registered with its full option set but is currently
# disabled: the body raises a ClickException before doing any work.
@cli.command(hidden=True)
@click.argument('task', type=click.STRING, metavar='TASK')
@shared_client_options
@click.option(
    '--agent',
    type=click.Choice(['hf']),
    default='hf',
    help='Whether to interact with Agents from given Server endpoint.',
    show_default=True,
)
@click.option(
    '--remote',
    is_flag=True,
    default=False,
    help='Whether or not to use remote tools (inference endpoints) instead of local ones.',
    show_default=True,
)
@click.option(
    '--opt',
    help="Define prompt options. (format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
    required=False,
    multiple=True,
    callback=opt_callback,
    metavar='ARG=VALUE[,ARG=VALUE]',
)
def instruct_command(
    endpoint: str,
    timeout: int,
    agent: LiteralString,
    output: LiteralOutput,
    remote: bool,
    task: str,
    _memoized: DictStrAny,
    **attrs: t.Any,
) -> str:
    """Instruct agents interactively for given tasks, from a terminal.

    \b
    ```bash
    $ openllm instruct --endpoint http://12.323.2.1:3000 \\
        "Is the following `text` (in Spanish) positive or negative?" \\
        --text "¡Este es un API muy agradable!"
    ```
    """
    # Always raises, so nothing is ever returned despite the ``-> str`` annotation.
    raise click.ClickException("'instruct' is currently disabled")
    # The previous implementation is kept below, commented out, for reference only.
    # client = openllm.client.HTTPClient(endpoint, timeout=timeout)
    #
    # try:
    #   client.call('metadata')
    # except http.client.BadStatusLine:
    #   raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
    # if agent == 'hf':
    #   _memoized = {k: v[0] for k, v in _memoized.items() if v}
    #   client._hf_agent.set_stream(logger.info)
    #   if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
    #   result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized)
    #   if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white')
    #   else: termui.echo(result, fg='white')
    #   return result
    # else:
    #   raise click.BadOptionUsage('agent', f'Unknown agent type {agent}')
|
||||
|
||||
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@click.option(
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"configuration": {
|
||||
"generation_config": {
|
||||
"diversity_penalty": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_repetition_penalty": 1.0,
|
||||
"epsilon_cutoff": 0.0,
|
||||
"eta_cutoff": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"max_new_tokens": 10,
|
||||
"min_length": 0,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"remove_invalid_values": false,
|
||||
"renormalize_logits": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 0.9,
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"typical_p": 1.0,
|
||||
"use_cache": true
|
||||
}
|
||||
},
|
||||
"responses": [
|
||||
"life is a complete physical life"
|
||||
]
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"configuration": {
|
||||
"generation_config": {
|
||||
"diversity_penalty": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_repetition_penalty": 1.0,
|
||||
"epsilon_cutoff": 0.0,
|
||||
"eta_cutoff": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"max_new_tokens": 10,
|
||||
"min_length": 0,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"remove_invalid_values": false,
|
||||
"renormalize_logits": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 0.9,
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"typical_p": 1.0,
|
||||
"use_cache": true
|
||||
}
|
||||
},
|
||||
"responses": [
|
||||
"life is a state"
|
||||
]
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"configuration": {
|
||||
"format_outputs": false,
|
||||
"generation_config": {
|
||||
"diversity_penalty": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_repetition_penalty": 1.0,
|
||||
"epsilon_cutoff": 0.0,
|
||||
"eta_cutoff": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"max_new_tokens": 20,
|
||||
"min_length": 0,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"remove_invalid_values": false,
|
||||
"renormalize_logits": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 0.75,
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"typical_p": 1.0,
|
||||
"use_cache": true
|
||||
}
|
||||
},
|
||||
"responses": [
|
||||
"What is Deep learning?\nDeep learning is a new way of studying the content and making an informed decision. It is the"
|
||||
]
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
{
|
||||
"configuration": {
|
||||
"format_outputs": false,
|
||||
"generation_config": {
|
||||
"diversity_penalty": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_repetition_penalty": 1.0,
|
||||
"epsilon_cutoff": 0.0,
|
||||
"eta_cutoff": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"max_new_tokens": 20,
|
||||
"min_length": 0,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"remove_invalid_values": false,
|
||||
"renormalize_logits": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 0.75,
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"typical_p": 1.0,
|
||||
"use_cache": true
|
||||
}
|
||||
},
|
||||
"responses": [
|
||||
"What is Deep learning?\n\nDeep learning is a new, highly-advanced, and powerful tool for the deep learning"
|
||||
]
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import typing as t
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import attr
|
||||
import docker
|
||||
import docker.errors
|
||||
import docker.types
|
||||
import orjson
|
||||
import pytest
|
||||
from syrupy.extensions.json import JSONSnapshotExtension
|
||||
|
||||
import openllm
|
||||
from bentoml._internal.types import LazyType
|
||||
from openllm._llm import self
|
||||
from openllm_core._typing_compat import DictStrAny, ListAny, LiteralQuantise
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import subprocess
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from syrupy.types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData
|
||||
|
||||
from openllm.client import BaseAsyncClient
|
||||
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
    """Snapshot extension that stores generation responses as JSON and compares them loosely.

    Two payloads match when they hold the same number of ``GenerationOutput``
    items and every pair has the same number of ``outputs``. The generated text
    itself is not compared.
    """

    def serialize(
        self,
        data: SerializableData,
        *,
        exclude: PropertyFilter | None = None,
        matcher: PropertyMatcher | None = None,
    ) -> SerializedData:
        """Serialise one response, or a list of responses, to indented, key-sorted JSON.

        Each response is reduced to its ``unmarshaled`` attribute, then passed
        through the base-class ``_filter`` with ``exclude`` and ``matcher``.
        """
        if LazyType(ListAny).isinstance(data):
            data = [d.unmarshaled for d in data]
        else:
            data = data.unmarshaled
        data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher)
        return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()

    def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
        """Return True when both JSON payloads have the same response/output counts.

        Raises:
            ValueError: if either payload is not valid JSON.
            NotImplementedError: if a decoded payload is neither a dict nor a list.
        """

        def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
            try:
                data = orjson.loads(data)
            except orjson.JSONDecodeError as err:
                raise ValueError(f'Failed to decode JSON data: {data}') from err
            if LazyType(DictStrAny).isinstance(data):
                return openllm.GenerationOutput(**data)
            elif LazyType(ListAny).isinstance(data):
                return [openllm.GenerationOutput(**d) for d in data]
            else:
                raise NotImplementedError(f'Data {data} has unsupported type.')

        def as_list(value: t.Any) -> t.Any:
            # Normalise a single output to a one-element list. The previous check was
            # inverted: it nested lists and left single outputs unwrapped.
            return value if LazyType(ListAny).isinstance(value) else [value]

        # Parameter names avoid ``t`` so the ``typing`` alias stays usable in annotations.
        def eq_output(current: openllm.GenerationOutput, recorded: openllm.GenerationOutput) -> bool:
            return len(current.outputs) == len(recorded.outputs)

        current_items = as_list(convert_data(serialized_data))
        recorded_items = as_list(convert_data(snapshot_data))

        return len(current_items) == len(recorded_items) and all(
            eq_output(current, recorded) for current, recorded in zip(current_items, recorded_items)
        )
|
||||
|
||||
|
||||
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
    """Return the ``snapshot`` fixture switched to the ``ResponseComparator`` extension."""
    comparing_snapshot = snapshot.use_extension(ResponseComparator)
    return comparing_snapshot
|
||||
|
||||
|
||||
@attr.define(init=False)
class _Handle(ABC):
    """Base class for a server under test that is reachable on ``localhost:<port>``.

    Subclasses implement ``status`` to report whether the underlying process or
    container is still alive.
    """

    port: int
    deployment_mode: t.Literal['container', 'local']

    # Not an initialiser argument; assigned in __attrs_post_init__.
    client: BaseAsyncClient[t.Any] = attr.field(init=False)

    if t.TYPE_CHECKING:

        def __attrs_init__(self, *args: t.Any, **attrs: t.Any): ...

    def __attrs_post_init__(self):
        self.client = openllm.client.AsyncHTTPClient(f'http://localhost:{self.port}')

    @abstractmethod
    def status(self) -> bool:
        """Return True while the underlying server is still alive."""
        raise NotImplementedError

    async def health(self, timeout: int = 240):
        """Poll the server until a sanity query succeeds.

        Args:
            timeout: maximum number of seconds to keep polling.

        Raises:
            RuntimeError: if ``status()`` reports the server is no longer alive,
                or if no query succeeds within ``timeout`` seconds.
        """
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            if not self.status():
                raise RuntimeError(f'Failed to initialise {self.__class__.__name__}')
            await self.client.health()
            try:
                await self.client.query('sanity')
                return
            except Exception:
                # Yield to the event loop between attempts; a blocking time.sleep()
                # here would stall every other coroutine on the loop.
                await asyncio.sleep(1)
        raise RuntimeError(f'Handle failed to initialise within {timeout} seconds.')
|
||||
|
||||
|
||||
@attr.define(init=False)
class LocalHandle(_Handle):
    """Handle for a server that runs as a local subprocess."""

    process: subprocess.Popen[bytes]

    def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal['container', 'local']):
        # Delegate to the attrs-generated initialiser, naming each field explicitly.
        self.__attrs_init__(port=port, deployment_mode=deployment_mode, process=process)

    def status(self) -> bool:
        """Return True while the subprocess has not exited."""
        exit_code = self.process.poll()
        return exit_code is None
|
||||
|
||||
|
||||
class HandleProtocol(t.Protocol):
    """Structural type of the handle factories returned by the ``handler`` fixture.

    Calling an implementation with keyword arguments yields a context manager
    that produces a ``_Handle``.
    """

    # ``self`` was missing from the signature, so the declared method could not
    # be called on an instance and did not describe a callable object.
    @contextlib.contextmanager
    def __call__(
        self, *, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None
    ) -> t.Generator[_Handle, None, None]: ...
|
||||
|
||||
|
||||
@attr.define(init=False)
class DockerHandle(_Handle):
    """Handle for a server that runs inside a Docker container."""

    container_name: str
    docker_client: docker.DockerClient

    def __init__(
        self,
        docker_client: docker.DockerClient,
        container_name: str,
        port: int,
        deployment_mode: t.Literal['container', 'local'],
    ):
        # Delegate to the attrs-generated initialiser, naming each field explicitly.
        self.__attrs_init__(
            port=port,
            deployment_mode=deployment_mode,
            container_name=container_name,
            docker_client=docker_client,
        )

    def status(self) -> bool:
        """Return True while the container is reported as running or created."""
        state = self.docker_client.containers.get(self.container_name).status
        return state == 'running' or state == 'created'
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _local_handle(
    model: str,
    model_id: str,
    image_tag: str,
    deployment_mode: t.Literal['container', 'local'],
    quantize: LiteralQuantise | None = None,
    *,
    _serve_grpc: bool = False,
):
    """Start a local server process and yield a ``LocalHandle`` for it.

    Args:
        model: model name passed to ``openllm.start`` / ``openllm.start_grpc``.
        model_id: model id passed through to the start call.
        image_tag: accepted for signature parity with the container handle; unused here.
        deployment_mode: stored on the yielded handle.
        quantize: optional quantisation setting passed through to the start call.
        _serve_grpc: use ``openllm.start_grpc`` instead of ``openllm.start``.

    The process is terminated and its output echoed to stderr when the context
    exits, including when the body of the ``with`` block raises.
    """
    # The port is released again before the server starts; only its number is reused.
    with openllm.utils.reserve_free_port() as port:
        pass

    start = openllm.start_grpc if _serve_grpc else openllm.start
    proc = start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)

    try:
        yield LocalHandle(proc, port, deployment_mode)
    finally:
        # Without try/finally an exception in the with-body skipped this cleanup
        # and left the server process running.
        proc.terminate()
        proc.wait(60)

        process_output = proc.stdout.read()
        print(process_output, file=sys.stderr)

        proc.stdout.close()
        if proc.stderr:
            proc.stderr.close()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _container_handle(
    model: str,
    model_id: str,
    image_tag: str,
    deployment_mode: t.Literal['container', 'local'],
    quantize: LiteralQuantise | None = None,
    *,
    _serve_grpc: bool = False,
):
    """Run ``image_tag`` in a detached container and yield a ``DockerHandle`` for it.

    Args:
        model: model name, used to build the container name.
        model_id: model id, used to build the container name.
        image_tag: image to run.
        deployment_mode: stored on the yielded handle.
        quantize: when set, exported to the container as ``OPENLLM_QUANTIZE``.
        _serve_grpc: run ``serve-grpc`` instead of ``serve``.

    The container is stopped, its logs echoed to stderr, and then removed when
    the context exits, including when the body of the ``with`` block raises.
    """
    # Both ports are released again before the container starts; only the numbers are reused.
    with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
        pass
    # NOTE(review): ``self`` is the name imported from ``openllm._llm`` at module level;
    # it looks like a rename artefact -- confirm which helper is intended here.
    container_name = f'openllm-{model}-{self(model_id)}'.replace('-', '_')
    client = docker.from_env()
    # Remove any stale container left over under the same name.
    try:
        container = client.containers.get(container_name)
        container.stop()
        container.wait()
        container.remove()
    except docker.errors.NotFound:
        pass

    args = ['serve' if not _serve_grpc else 'serve-grpc']

    env: DictStrAny = {}

    if quantize is not None:
        env['OPENLLM_QUANTIZE'] = quantize

    # A falsy device count becomes -1 so that no GPU device request is made.
    gpus = openllm.utils.device_count() or -1
    devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None

    container = client.containers.run(
        image_tag,
        command=args,
        name=container_name,
        environment=env,
        auto_remove=False,
        detach=True,
        device_requests=devs,
        ports={'3000/tcp': port, '3001/tcp': prom_port},
    )

    try:
        yield DockerHandle(client, container.name, port, deployment_mode)
    finally:
        # Without try/finally an exception in the with-body skipped this cleanup
        # and left the container running.
        try:
            container.stop()
            container.wait()
        except docker.errors.NotFound:
            pass

        container_output = container.logs().decode('utf-8')
        print(container_output, file=sys.stderr)

        container.remove()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
    """Provide a session-wide ``ExitStack`` that is closed when the session ends."""
    # The ``with`` form closes the stack on every exit path of the generator,
    # not only when execution resumes normally after the yield.
    with contextlib.ExitStack() as stack:
        yield stack
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a fresh event loop for each test module and close it afterwards."""
    # The fixture used to close whatever get_event_loop() returned, so the next
    # module received the same, already-closed loop. Creating a new loop per
    # module keeps modules independent of each other.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(params=['container', 'local'], scope='session')
def deployment_mode(request: pytest.FixtureRequest) -> str:
    """Return the deployment mode selected by the current fixture parameter."""
    selected_mode = request.param
    return selected_mode
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal['container', 'local']):
    """Return the handle factory for ``deployment_mode`` with that mode pre-bound.

    ``el`` is requested as a fixture dependency but is not used in the body.

    Raises:
        ValueError: if ``deployment_mode`` is neither ``'container'`` nor ``'local'``.
    """
    if deployment_mode == 'container':
        factory = _container_handle
    elif deployment_mode == 'local':
        factory = _local_handle
    else:
        raise ValueError(f'Unknown deployment mode: {deployment_mode}')
    return functools.partial(factory, deployment_mode=deployment_mode)
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
|
||||
# Model name and model id shared by the fixtures in this module.
model = 'flan_t5'
model_id = 'google/flan-t5-small'
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
def flan_t5_handle(
    handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack
):
    """Prepare the model for the current deployment mode and yield a handle to its server."""
    prepared = openllm.testing.prepare(
        model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
    )
    # The handler context is entered only after preparation has produced the image tag.
    with prepared as image_tag, handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
        yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
async def flan_t5(flan_t5_handle: _Handle):
    """Wait up to 240 seconds for the server to become healthy, then return its client."""
    await flan_t5_handle.health(240)
    ready_client = flan_t5_handle.client
    return ready_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
    """Query the server and check the response against the stored snapshot."""
    client = await flan_t5
    prompt = 'What is the meaning of life?'
    response = await client.query(prompt, max_new_tokens=10, top_p=0.9, return_response='attrs')

    generation_config = response.configuration['generation_config']
    assert generation_config['max_new_tokens'] == 10
    assert response == response_snapshot
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
|
||||
# Model name and model id shared by the fixtures in this module.
model = 'opt'
model_id = 'facebook/opt-125m'
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
def opt_125m_handle(
    handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack
):
    """Prepare the model for the current deployment mode and yield a handle to its server."""
    prepared = openllm.testing.prepare(
        model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
    )
    # The handler context is entered only after preparation has produced the image tag.
    with prepared as image_tag, handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
        yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
async def opt_125m(opt_125m_handle: _Handle):
    """Wait up to 240 seconds for the server to become healthy, then return its client."""
    await opt_125m_handle.health(240)
    ready_client = opt_125m_handle.client
    return ready_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
    """Query the server and check the response against the stored snapshot."""
    client = await opt_125m
    prompt = 'What is Deep learning?'
    response = await client.query(prompt, max_new_tokens=20, return_response='attrs')

    generation_config = response.configuration['generation_config']
    assert generation_config['max_new_tokens'] == 20
    assert response == response_snapshot
|
||||
Reference in New Issue
Block a user