refactor(config): simplify configuration and update start CLI output (#611)

* chore(config): simplify configuration and update start CLI output
handling

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: remove state and message sent after server lifecycle

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update color stream and refactor reusable logic

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update documentation and mypy

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2023-11-11 22:36:10 -05:00
Committed by: GitHub
Parent: 36559a5ab5
Commit: 7438005c04
19 changed files with 414 additions and 469 deletions

View File

@@ -9,44 +9,17 @@ deploy, and monitor any LLMs with ease.
* Native integration with BentoML and LangChain for custom LLM apps
"""
from __future__ import annotations
import logging as _logging
import os as _os
import typing as _t
import pathlib as _pathlib
import warnings as _warnings
from pathlib import Path as _Path
import openllm_core
from openllm_core._configuration import GenerationConfig as GenerationConfig
from openllm_core._configuration import LLMConfig as LLMConfig
from openllm_core._configuration import SamplingParams as SamplingParams
from openllm_core._schemas import GenerationInput as GenerationInput
from openllm_core._schemas import GenerationOutput as GenerationOutput
from openllm_core._schemas import MetadataOutput as MetadataOutput
from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING
from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
from openllm_core.config import AutoConfig as AutoConfig
from openllm_core.config import BaichuanConfig as BaichuanConfig
from openllm_core.config import ChatGLMConfig as ChatGLMConfig
from openllm_core.config import DollyV2Config as DollyV2Config
from openllm_core.config import FalconConfig as FalconConfig
from openllm_core.config import FlanT5Config as FlanT5Config
from openllm_core.config import GPTNeoXConfig as GPTNeoXConfig
from openllm_core.config import LlamaConfig as LlamaConfig
from openllm_core.config import MistralConfig as MistralConfig
from openllm_core.config import MPTConfig as MPTConfig
from openllm_core.config import OPTConfig as OPTConfig
from openllm_core.config import StableLMConfig as StableLMConfig
from openllm_core.config import StarCoderConfig as StarCoderConfig
from . import exceptions as exceptions
from . import utils as utils
if openllm_core.utils.DEBUG:
openllm_core.utils.set_debug_mode(True)
openllm_core.utils.set_quiet_mode(False)
if utils.DEBUG:
utils.set_debug_mode(True)
utils.set_quiet_mode(False)
_logging.basicConfig(level=_logging.NOTSET)
else:
# configuration for bitsandbytes before import
@@ -64,68 +37,36 @@ else:
'ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated'
)
_import_structure: dict[str, list[str]] = {
'exceptions': [],
'client': ['HTTPClient', 'AsyncHTTPClient'],
'bundle': [],
'playground': [],
'testing': [],
'prompts': ['PromptTemplate'],
'protocol': [],
'utils': [],
'_deprecated': ['Runner'],
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
'entrypoints': ['mount_entrypoints'],
'serialisation': ['ggml', 'transformers'],
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
'_quantisation': ['infer_quantisation_config'],
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
'_generation': [
'StopSequenceCriteria',
'StopOnTokens',
'LogitsProcessorList',
'StoppingCriteriaList',
'prepare_logits_processor',
],
}
COMPILED = _Path(__file__).suffix in ('.pyd', '.so')
if _t.TYPE_CHECKING:
from . import bundle as bundle
from . import cli as cli
from . import client as client
from . import playground as playground
from . import serialisation as serialisation
from . import testing as testing
from . import utils as utils
from .client import HTTPClient as HTTPClient
from .client import AsyncHTTPClient as AsyncHTTPClient
from ._deprecated import Runner as Runner
from ._generation import LogitsProcessorList as LogitsProcessorList
from ._generation import StopOnTokens as StopOnTokens
from ._generation import StoppingCriteriaList as StoppingCriteriaList
from ._generation import StopSequenceCriteria as StopSequenceCriteria
from ._generation import prepare_logits_processor as prepare_logits_processor
from ._llm import LLM as LLM
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy
from ._strategies import get_resource as get_resource
from .cli._sdk import build as build
from .cli._sdk import import_model as import_model
from .cli._sdk import list_models as list_models
from .cli._sdk import start as start
from .cli._sdk import start_grpc as start_grpc
from .entrypoints import mount_entrypoints as mount_entrypoints
from .prompts import PromptTemplate as PromptTemplate
from .protocol import openai as openai
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
COMPILED = _pathlib.Path(__file__).suffix in ('.pyd', '.so')
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
__lazy = openllm_core.utils.LazyModule(
__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED}
__lazy = utils.LazyModule(
__name__,
globals()['__file__'],
{
'exceptions': [],
'client': ['HTTPClient', 'AsyncHTTPClient'],
'bundle': [],
'playground': [],
'testing': [],
'protocol': [],
'utils': [],
'_deprecated': ['Runner'],
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
'entrypoints': ['mount_entrypoints'],
'serialisation': ['ggml', 'transformers'],
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
'_quantisation': ['infer_quantisation_config'],
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
'_generation': [
'StopSequenceCriteria',
'StopOnTokens',
'LogitsProcessorList',
'StoppingCriteriaList',
'prepare_logits_processor',
],
},
extra_objects={'COMPILED': COMPILED},
)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
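The block above replaces the standalone `_import_structure` dict with an inline mapping passed straight to `utils.LazyModule`, so `import openllm` stays cheap and submodules are only imported when first touched. For orientation, a minimal sketch of that lazy-module pattern is shown below; it is a simplified stand-in, not the actual `openllm_core.utils.LazyModule` implementation, and the class name is illustrative.

```python
# Minimal sketch of a lazy module: submodule imports are deferred until an
# attribute listed in the import structure is first accessed, then cached.
import importlib
import types


class LazyModuleSketch(types.ModuleType):
  def __init__(self, name, import_structure):
    super().__init__(name)
    # map each exported attribute back to the submodule that defines it
    self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}
    self.__all__ = sorted(self._attr_to_module) + sorted(import_structure)

  def __getattr__(self, item):
    if item in self._attr_to_module:
      submodule = importlib.import_module(f'.{self._attr_to_module[item]}', self.__name__)
      value = getattr(submodule, item)
    elif item in self.__all__:
      value = importlib.import_module(f'.{item}', self.__name__)
    else:
      raise AttributeError(item)
    setattr(self, item, value)  # cache so later lookups bypass __getattr__
    return value


# The real module swaps an instance like this into sys.modules[__name__].
```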

View File

@@ -0,0 +1,54 @@
from openllm_core._configuration import GenerationConfig as GenerationConfig
from openllm_core._configuration import LLMConfig as LLMConfig
from openllm_core._configuration import SamplingParams as SamplingParams
from openllm_core._schemas import GenerationInput as GenerationInput
from openllm_core._schemas import GenerationOutput as GenerationOutput
from openllm_core._schemas import MetadataOutput as MetadataOutput
from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING
from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
from openllm_core.config import AutoConfig as AutoConfig
from openllm_core.config import BaichuanConfig as BaichuanConfig
from openllm_core.config import ChatGLMConfig as ChatGLMConfig
from openllm_core.config import DollyV2Config as DollyV2Config
from openllm_core.config import FalconConfig as FalconConfig
from openllm_core.config import FlanT5Config as FlanT5Config
from openllm_core.config import GPTNeoXConfig as GPTNeoXConfig
from openllm_core.config import LlamaConfig as LlamaConfig
from openllm_core.config import MistralConfig as MistralConfig
from openllm_core.config import MPTConfig as MPTConfig
from openllm_core.config import OPTConfig as OPTConfig
from openllm_core.config import StableLMConfig as StableLMConfig
from openllm_core.config import StarCoderConfig as StarCoderConfig
from . import exceptions as exceptions
from . import bundle as bundle
from . import cli as cli
from . import client as client
from . import playground as playground
from . import serialisation as serialisation
from . import testing as testing
from . import utils as utils
from ._deprecated import Runner as Runner
from ._generation import LogitsProcessorList as LogitsProcessorList
from ._generation import StopOnTokens as StopOnTokens
from ._generation import StoppingCriteriaList as StoppingCriteriaList
from ._generation import StopSequenceCriteria as StopSequenceCriteria
from ._generation import prepare_logits_processor as prepare_logits_processor
from ._llm import LLM as LLM
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy
from ._strategies import get_resource as get_resource
from .cli._sdk import build as build
from .cli._sdk import import_model as import_model
from .cli._sdk import list_models as list_models
from .cli._sdk import start as start
from .cli._sdk import start_grpc as start_grpc
from .client import AsyncHTTPClient as AsyncHTTPClient
from .client import HTTPClient as HTTPClient
from .entrypoints import mount_entrypoints as mount_entrypoints
from .protocol import openai as openai
from .serialisation import ggml as ggml
from .serialisation import transformers as transformers
COMPILED: bool = ...
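This stub mirrors the lazy module's public surface so static type checkers (mypy, pyright) can resolve `openllm.<symbol>` without running the lazy-import machinery. As a small, hedged example of what that enables (the model name below is illustrative):

```python
# Type checkers resolve AutoConfig from the .pyi stub above; at runtime the
# attribute is materialised lazily by LazyModule on first access.
import openllm

config = openllm.AutoConfig.for_model('opt')  # 'opt' is an illustrative model name
print(config['model_ids'])
```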

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import logging
import os
import typing as t
import warnings
@@ -6,39 +7,28 @@ import warnings
import openllm
from openllm_core._typing_compat import LiteralBackend
from openllm_core._typing_compat import ParamSpec
from openllm_core.utils import first_not_none
from openllm_core.utils import is_vllm_available
if t.TYPE_CHECKING:
from openllm_core import LLMConfig
from openllm_core._typing_compat import ParamSpec
P = ParamSpec('P')
from ._llm import LLMRunner
P = ParamSpec('P')
_object_setattr = object.__setattr__
logger = logging.getLogger(__name__)
def _mark_deprecated(fn: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
_object_setattr(fn, '__deprecated__', True)
return fn
@_mark_deprecated
def Runner(
model_name: str,
ensure_available: bool = True,
init_local: bool = False,
backend: LiteralBackend | None = None,
llm_config: LLMConfig | None = None,
llm_config: openllm.LLMConfig | None = None,
**attrs: t.Any,
) -> LLMRunner[t.Any, t.Any]:
) -> openllm.LLMRunner[t.Any, t.Any]:
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
> [!WARNING]
> This method is now deprecated in favor of 'openllm.LLM.runner'
> This method is now deprecated in favor of 'openllm.LLM'
```python
runner = openllm.Runner("dolly-v2")
@@ -65,6 +55,10 @@ def Runner(
if llm_config is None:
llm_config = openllm.AutoConfig.for_model(model_name)
if not ensure_available:
logger.warning(
"'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation."
)
model_id = attrs.get('model_id', default=os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
_RUNNER_MSG = f"""\
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
@@ -99,10 +93,4 @@ def Runner(
return llm.runner
_DEPRECATED = {k: v for k, v in locals().items() if getattr(v, '__deprecated__', False)}
__all__ = list(_DEPRECATED)
def __dir__() -> list[str]:
return sorted(_DEPRECATED.keys())
__all__ = ['Runner']
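For reference, the non-deprecated path that the deprecation message points to looks roughly like the sketch below; the model id and service name are illustrative and not taken from this diff.

```python
# Instead of openllm.Runner('dolly-v2'), construct an LLM directly and use its
# cached `.runner` property when wiring up a BentoML service.
import bentoml
import openllm

llm = openllm.LLM('databricks/dolly-v2-3b')  # illustrative model id
svc = bentoml.Service('llm-service', runners=[llm.runner])
```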

View File

@@ -9,6 +9,8 @@ import typing as t
import attr
import inflection
import orjson
import torch
import transformers
from huggingface_hub import hf_hub_download
@@ -44,7 +46,6 @@ from openllm_core.utils import first_not_none
from openllm_core.utils import flatten_attrs
from openllm_core.utils import generate_hash_from_file
from openllm_core.utils import is_peft_available
from openllm_core.utils import is_torch_available
from openllm_core.utils import resolve_filepath
from openllm_core.utils import validate_is_path
@@ -57,8 +58,6 @@ from .serialisation.constants import PEFT_CONFIG_NAME
if t.TYPE_CHECKING:
import peft
import torch
import transformers
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
@@ -68,8 +67,6 @@ if t.TYPE_CHECKING:
from openllm_core.utils.representation import ReprArgs
else:
transformers = LazyLoader('transformers', globals(), 'transformers')
torch = LazyLoader('torch', globals(), 'torch')
peft = LazyLoader('peft', globals(), 'peft')
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['peft.PeftConfig', str]]]
@@ -146,12 +143,6 @@ class LLM(t.Generic[M, T], ReprMixin):
__llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
__llm_trust_remote_code__: bool = False
device: 'torch.device | None' = None
def __attrs_post_init__(self) -> None:
if self.__llm_backend__ == 'pt':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def __init__(
self,
model_id: str,
@@ -233,10 +224,7 @@ class LLM(t.Generic[M, T], ReprMixin):
self._bentomodel = model
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(
self, model_id: str, model_version: str | None, backend: LiteralBackend
) -> tuple[str, str | None]:
"""Return a valid tag name (<backend>-<repo>--<model_id>) and its tag version."""
def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
model_id, *maybe_revision = model_id.rsplit(':')
if len(maybe_revision) > 0:
if model_version is not None:
@@ -254,16 +242,16 @@ class LLM(t.Generic[M, T], ReprMixin):
return f'{backend}-{normalise_model_name(model_id)}', model_version
# yapf: disable
def __setattr__(self,attr: str,value: t.Any)->None:
def __setattr__(self,attr,value):
if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr,value)
@property
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self) -> ReprArgs:
yield 'model_id', self._model_id if not self._local else self.tag.name
yield 'revision', self._revision if self._revision else self.tag.version
yield 'backend', self.__llm_backend__
yield 'type', self.llm_type
def __repr_args__(self):
yield 'model_id',self._model_id if not self._local else self.tag.name
yield 'revision',self._revision if self._revision else self.tag.version
yield 'backend',self.__llm_backend__
yield 'type',self.llm_type
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
@property
@@ -277,11 +265,7 @@ class LLM(t.Generic[M, T], ReprMixin):
@property
def tag(self)->bentoml.Tag:return self._tag
@property
def bentomodel(self)->bentoml.Model:return self._bentomodel
@property
def config(self)->LLMConfig:
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
return self.__llm_config__
def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
@property
def quantization_config(self)->transformers.BitsAndBytesConfig|transformers.GPTQConfig|transformers.AwqConfig:
if self.__llm_quantization_config__ is None:
@@ -300,31 +284,53 @@ class LLM(t.Generic[M, T], ReprMixin):
def identifying_params(self)->DictStrAny:return {'configuration': self.config.model_dump_json().decode(),'model_ids': orjson.dumps(self.config['model_ids']).decode(),'model_id': self.model_id}
@property
def llm_parameters(self)->tuple[tuple[tuple[t.Any,...],DictStrAny],DictStrAny]:return (self._model_decls,self._model_attrs),self._tokenizer_attrs
# yapf: enable
# NOTE: The section helps with fine-tuning.
# NOTE: This section holds the actual model, tokenizer, and config references.
@property
def config(self)->LLMConfig:
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
return self.__llm_config__
@property
def tokenizer(self)->T:
if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self,**self.llm_parameters[-1])
return self.__llm_tokenizer__
@property
def runner(self)->LLMRunner[M, T]:
if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self)
return self.__llm_runner__
@property
def model(self)->M:
if self.__llm_model__ is None:
model=openllm.serialisation.load_model(self,*self._model_decls,**self._model_attrs)
# If OOM occurs, you probably don't have enough VRAM to run this model.
if self.__llm_backend__ == 'pt':
loaded_in_kbit = getattr(model,'is_loaded_in_8bit',False) or getattr(model,'is_loaded_in_4bit',False) or getattr(model,'is_quantized',False)
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
try: model = model.to('cuda')
except Exception as err: raise OpenLLMException(f'Failed to load model into GPU: {err}.\nSee https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.') from err
if self.has_adapters:
logger.debug('Applying the following adapters: %s', self.adapter_map)
for adapter_dict in self.adapter_map.values():
for adapter_name, (peft_config, peft_model_id) in adapter_dict.items():
model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
self.__llm_model__ = model
return self.__llm_model__
@property
def adapter_map(self) -> ResolvedAdapterMap:
try:
import peft as _ # noqa: F401
except ImportError as err:
raise MissingDependencyError(
"Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'"
) from err
if not self.has_adapters:
raise AttributeError('Adapter map is not available.')
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err
if not self.has_adapters: raise AttributeError('Adapter map is not available.')
assert self._adapter_map is not None
if self.__llm_adapter_map__ is None:
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
for adapter_type, adapter_tuple in self._adapter_map.items():
base = first_not_none(
self.config['fine_tune_strategies'].get(adapter_type),
default=self.config.make_fine_tune_config(adapter_type),
)
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
for adapter in adapter_tuple:
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
self.__llm_adapter_map__ = _map
return self.__llm_adapter_map__
# yapf: enable
def prepare_for_training(
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
@@ -343,56 +349,11 @@ class LLM(t.Generic[M, T], ReprMixin):
raise ValueError('Adapter should not be specified when fine-tuning.')
model = get_peft_model(
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config
) # type: ignore[no-untyped-call]
)
if DEBUG:
model.print_trainable_parameters() # type: ignore[no-untyped-call]
return model, self.tokenizer
@property
def model(self) -> M:
if self.__llm_model__ is None:
model = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
# If OOM occurs, you probably don't have enough VRAM to run this model.
if self.__llm_backend__ == 'pt':
if is_torch_available():
loaded_in_kbit = (
getattr(model, 'is_loaded_in_8bit', False)
or getattr(model, 'is_loaded_in_4bit', False)
or getattr(model, 'is_quantized', False)
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and not loaded_in_kbit
and not isinstance(model, transformers.Pipeline)
):
try:
model = model.to('cuda')
except Exception as err:
raise OpenLLMException(
f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issues, try a different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.'
) from err
if self.has_adapters:
logger.debug('Applying the following adapters: %s', self.adapter_map)
for adapter_dict in self.adapter_map.values():
for adapter_name, (peft_config, peft_model_id) in adapter_dict.items():
model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
self.__llm_model__ = model
return self.__llm_model__
@property
def tokenizer(self) -> T:
# NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer
if self.__llm_tokenizer__ is None:
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
return self.__llm_tokenizer__
@property
def runner(self) -> LLMRunner[M, T]:
if self.__llm_runner__ is None:
self.__llm_runner__ = _RunnerFactory(self)
return self.__llm_runner__
async def generate(
self,
prompt: str | None,

View File

@@ -23,7 +23,6 @@ bentomodel = openllm.import_model('mistralai/Mistral-7B-v0.1')
from __future__ import annotations
import enum
import functools
import importlib.util
import inspect
import itertools
import logging
@@ -31,6 +30,7 @@ import os
import platform
import random
import subprocess
import threading
import time
import traceback
import typing as t
@@ -65,6 +65,7 @@ from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import NotRequired
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import Self
from openllm_core._typing_compat import TypeGuard
from openllm_core.config import CONFIG_MAPPING
from openllm_core.utils import DEBUG_ENV_VAR
from openllm_core.utils import OPTIONAL_DEPENDENCIES
@@ -451,84 +452,31 @@ def start_command(
# XXX: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin':
wpr = 1.0
elif wpr == 'conserved':
if device and openllm.utils.device_count() == 0:
termui.echo('--device will have no effect as there are no GPUs available', fg='yellow')
wpr = 1.0
else:
available_gpu = len(device) if device else openllm.utils.device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else:
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
requirements = llm.config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0:
termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
start_env.update(
{
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': llm.config.model_dump_json(flatten=True).decode(),
}
start_env = process_environ(
config,
server_timeout,
process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device),
device,
cors,
model_id,
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
)
if llm._quantise:
start_env['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message:
start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
server = bentoml.HTTPServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
openllm.utils.analytics.track_start_init(config)
def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None:
if caught_exception:
return
cmd_name = f'openllm build {model_id}'
if llm._quantise:
cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
[
f'--adapter-id {s}'
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
]
)
if not openllm.utils.get_quiet_mode():
termui.info(f"\n\n🚀 Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
_exception = False
if return_process:
server.start(env=start_env, text=True)
if server.process is None:
raise click.ClickException('Failed to start the server.')
return server.process
else:
try:
server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt:
_exception = True
except Exception as err:
termui.error(f'Error caught while running LLM Server:\n{err}')
_exception = True
raise
else:
next_step(adapter_map, _exception)
try:
build_bento_instruction(llm, model_id, serialisation, adapter_map)
it = run_server(server.args, start_env, return_process=return_process)
if return_process:
return it
except KeyboardInterrupt:
pass
# NOTE: Return the configuration for telemetry purposes.
return config
@@ -550,7 +498,9 @@ def start_command(
help='Deprecated. Use positional argument instead.',
)
@start_decorator(serve_grpc=True)
@click.pass_context
def start_grpc_command(
ctx: click.Context,
model_id: str,
server_timeout: int,
model_version: str | None,
@@ -626,7 +576,61 @@ def start_grpc_command(
# XXX: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
start_env = process_environ(
config,
server_timeout,
process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device),
device,
cors,
model_id,
adapter_map,
serialisation,
llm,
system_message,
prompt_template,
)
server = bentoml.GrpcServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
try:
build_bento_instruction(llm, model_id, serialisation, adapter_map)
it = run_server(server.args, start_env, return_process=return_process)
if return_process:
return it
except KeyboardInterrupt:
pass
# NOTE: Return the configuration for telemetry purposes.
return config
def process_environ(
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, system_message, prompt_template
) -> t.Dict[str, t.Any]:
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
environ.update(
{
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
}
)
if llm._quantise:
environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message:
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
return environ
def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]:
if isinstance(wpr, str):
if wpr == 'round_robin':
wpr = 1.0
@@ -641,76 +645,82 @@ def start_grpc_command(
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
return wpr
requirements = llm.config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0:
termui.warning(f'Make sure to have the following dependencies available: {missing_requirements}')
start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
start_env.update(
{
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': llm.config.model_dump_json().decode(),
}
)
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
cmd_name = f'openllm build {model_id}'
if llm._quantise:
start_env['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message:
start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
[
f'--adapter-id {s}'
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
]
)
if not openllm.utils.get_quiet_mode():
termui.info(f"🚀Tip: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
server = bentoml.GrpcServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None:
if caught_exception:
return
cmd_name = f'openllm build {model_id}'
if llm._quantise:
cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
[
f'--adapter-id {s}'
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
]
)
if not openllm.utils.get_quiet_mode():
termui.info(f"\n🚀 Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
_exception = False
if return_process:
server.start(env=start_env, text=True)
if server.process is None:
raise click.ClickException('Failed to start the server.')
return server.process
def pretty_print(line: str):
if 'WARNING' in line:
caller = termui.warning
elif 'INFO' in line:
caller = termui.info
elif 'DEBUG' in line:
caller = termui.debug
elif 'ERROR' in line:
caller = termui.error
else:
try:
server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt:
_exception = True
except Exception as err:
termui.error(f'Error caught while running LLM Server:\n{err}')
_exception = True
raise
else:
next_step(adapter_map, _exception)
caller = functools.partial(termui.echo, fg=None)
caller(line.strip())
# NOTE: Return the configuration for telemetry purposes.
return config
def handle(stream, stop_event):
try:
for line in iter(stream.readline, ''):
if stop_event.is_set():
break
pretty_print(line)
finally:
stream.close()
def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int:
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, text=True)
if return_process:
return process
stop_event = threading.Event()
# yapf: disable
stdout, stderr = threading.Thread(target=handle, args=(process.stdout, stop_event)), threading.Thread(target=handle, args=(process.stderr, stop_event))
stdout.start(); stderr.start()
try:
process.wait()
except KeyboardInterrupt:
stop_event.set()
process.terminate()
try:
process.wait(0.1)
except subprocess.TimeoutExpired:
# terminate() did not complete in time; the process may not exit cleanly, so force-kill it
process.kill()
raise
finally:
stop_event.set()
stdout.join(); stderr.join()
if process.poll() is not None: process.kill()
stdout.join(); stderr.join()
# yapf: enable
return process.returncode
class ItemState(enum.Enum):
NOT_FOUND = 'NOT_FOUND'
ADDED = 'ADDED'
EXISTS = 'EXISTS'
OVERWRITE = 'OVERWRITE'
@@ -808,6 +818,7 @@ def import_command(
model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
if llm.__llm_backend__ == 'pt' and is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()
state = ItemState.ADDED
response = ImportModelOutput(state=state, backend=llm.__llm_backend__, tag=str(model.tag))
termui.echo(orjson.dumps(response).decode(), fg='white')
return response
@@ -1052,6 +1063,8 @@ def build_command(
container_registry=container_registry,
container_version_strategy=container_version_strategy,
)
if state != ItemState.OVERWRITE:
state = ItemState.ADDED
except Exception as err:
traceback.print_exc()
raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err
@@ -1074,10 +1087,10 @@ def build_command(
if machine:
termui.echo(f'__object__:{orjson.dumps(response).decode()}\n\n', fg='white')
elif not get_quiet_mode() and (not push or not containerize):
if not overwrite:
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
elif state != ItemState.EXISTS:
if state != ItemState.EXISTS:
termui.info(f"Successfully built Bento '{bento.tag}'.\n")
elif not overwrite:
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
if not get_debug_mode():
termui.echo(OPENLLM_FIGLET)
termui.echo('\n📖 Next steps:\n\n', nl=False)
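The start commands above now delegate the server lifecycle to `run_server`, which streams the child process's stdout/stderr through reader threads and colours lines via `pretty_print`. Below is a self-contained sketch of that forwarding pattern with the `termui` colouring replaced by a plain `print`; it is an approximation of the helper, not the exact implementation.

```python
# Sketch of thread-based stdout/stderr forwarding for a child server process.
import subprocess
import threading


def _forward(stream, stop_event, write):
  try:
    for line in iter(stream.readline, ''):
      if stop_event.is_set():
        break
      write(line.rstrip())
  finally:
    stream.close()


def run(args, env=None):
  process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, text=True)
  stop_event = threading.Event()
  readers = [threading.Thread(target=_forward, args=(s, stop_event, print)) for s in (process.stdout, process.stderr)]
  for reader in readers:
    reader.start()
  try:
    process.wait()
  except KeyboardInterrupt:
    stop_event.set()
    process.terminate()
    raise
  finally:
    stop_event.set()
    for reader in readers:
      reader.join()
  return process.returncode


if __name__ == '__main__':
  print(run(['python', '-c', 'print("child says hi")']))
```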

View File

@@ -36,6 +36,16 @@ class Level(enum.IntEnum):
Level.CRITICAL: 'red',
}[self]
@classmethod
def from_logging_level(cls, level: int) -> Level:
return {
logging.DEBUG: Level.DEBUG,
logging.INFO: Level.INFO,
logging.WARNING: Level.WARNING,
logging.ERROR: Level.ERROR,
logging.CRITICAL: Level.CRITICAL,
}[level]
class JsonLog(t.TypedDict):
log_level: Level
@@ -43,13 +53,7 @@ class JsonLog(t.TypedDict):
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
def caller(text: str) -> None:
if get_debug_mode():
logger.log(level.value, text)
else:
echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
caller(orjson.dumps(JsonLog(log_level=level, content=content)).decode())
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
warning = functools.partial(log, level=Level.WARNING)
@@ -61,18 +65,17 @@ notset = functools.partial(log, level=Level.NOTSET)
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
if json and not isinstance(text, dict):
raise TypeError('text must be a dict')
if json:
text = orjson.loads(text)
if 'content' in text and 'log_level' in text:
content = t.cast(DictStrAny, text)['content']
fg = t.cast(Level, text['log_level']).color
content = text['content']
fg = Level.from_logging_level(text['log_level']).color
else:
content = orjson.dumps(text).decode()
fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
else:
content = t.cast(str, text)
attrs['fg'] = fg if not get_debug_mode() else None
attrs['fg'] = fg
if not get_quiet_mode():
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
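The `log` helper above now serialises a `JsonLog` payload and lets `echo` recover the colour from the standard-library level via `Level.from_logging_level`. The snippet below is a rough, self-contained sketch of that round trip using stand-ins; the palette and function bodies are illustrative, not the real `termui` objects.

```python
# Stand-in sketch: a log record is dumped to JSON with its stdlib logging level,
# then re-parsed to pick a colour before printing.
import json
import logging

_LEVEL_COLOURS = {logging.DEBUG: 'cyan', logging.INFO: 'green', logging.WARNING: 'yellow',
                  logging.ERROR: 'red', logging.CRITICAL: 'red'}  # illustrative palette


def log(content, level=logging.INFO):
  echo(json.dumps({'log_level': level, 'content': content}))


def echo(payload):
  record = json.loads(payload)
  colour = _LEVEL_COLOURS.get(record['log_level'], 'white')
  print(f'[{colour}] {record["content"]}')  # the real code uses click.secho(..., fg=colour)


log('model weights are not cached locally', logging.WARNING)  # -> [yellow] model weights are not cached locally
```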

View File

@@ -1,27 +1,9 @@
"""OpenLLM Python client.
```python
client = openllm.client.HTTPClient("http://localhost:8080")
client.query("What is the difference between gather and scatter?")
```
"""
from __future__ import annotations
import typing as t
import openllm_client
import openllm_client as _client
if t.TYPE_CHECKING:
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
from openllm_client import HTTPClient as HTTPClient
# from openllm_client import AsyncGrpcClient as AsyncGrpcClient
# from openllm_client import GrpcClient as GrpcClient
def __dir__():
return sorted(dir(_client))
def __dir__() -> t.Sequence[str]:
return sorted(dir(openllm_client))
def __getattr__(it: str) -> t.Any:
return getattr(openllm_client, it)
def __getattr__(it):
return getattr(_client, it)

View File

@@ -0,0 +1,10 @@
"""OpenLLM Python client.
```python
client = openllm.client.HTTPClient("http://localhost:8080")
client.query("What is the difference between gather and scatter?")
```
"""
from openllm_client import HTTPClient as HTTPClient
from openllm_client import AsyncHTTPClient as AsyncHTTPClient

View File

@@ -1,7 +1,3 @@
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
from __future__ import annotations
from openllm_core.exceptions import Error as Error
from openllm_core.exceptions import FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError
from openllm_core.exceptions import ForbiddenAttributeError as ForbiddenAttributeError

View File

@@ -1,3 +0,0 @@
from __future__ import annotations
from openllm_core.prompts import PromptTemplate as PromptTemplate