From 7438005c043fda51b4c4f618ca062d3799a0d93d Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sat, 11 Nov 2023 22:36:10 -0500 Subject: [PATCH] refactor(config): simplify configuration and update start CLI output (#611) * chore(config): simplify configuration and update start CLI output handling Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: remove state and message sent after server lifecycle Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update color stream and refactor reusable logic Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update documentations and mypy Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +- README.md | 21 ++ mypy.ini | 3 +- openllm-core/src/openllm_core/__init__.py | 3 - .../src/openllm_core/_configuration.py | 112 +++---- openllm-core/src/openllm_core/utils/lazy.py | 3 + openllm-core/src/openllm_core/utils/peft.py | 13 + openllm-python/src/openllm/__init__.py | 125 ++------ openllm-python/src/openllm/__init__.pyi | 54 ++++ openllm-python/src/openllm/_deprecated.py | 36 +-- openllm-python/src/openllm/_llm.py | 129 +++----- openllm-python/src/openllm/cli/entrypoint.py | 287 +++++++++--------- openllm-python/src/openllm/cli/termui.py | 27 +- openllm-python/src/openllm/client.py | 28 +- openllm-python/src/openllm/client.pyi | 10 + openllm-python/src/openllm/exceptions.py | 4 - openllm-python/src/openllm/prompts.py | 3 - openllm-python/tests/conftest.py | 7 +- openllm-python/tests/models_test.py | 12 +- 19 files changed, 414 insertions(+), 469 deletions(-) create mode 100644 openllm-python/src/openllm/__init__.pyi create mode 100644 openllm-python/src/openllm/client.pyi delete mode 100644 openllm-python/src/openllm/prompts.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6fc9a9e..b4821c93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ ci: autoupdate_schedule: weekly skip: [eslint, prettier, mypy] autofix_commit_msg: "ci: auto fixes from pre-commit.ci\n\nFor more information, see https://pre-commit.ci" - autoupdate_commit_msg: "ci: pre-commit autoupdate [pre-commit.ci]" + autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]' default_language_version: python: python3.9 # NOTE: sync with .python-version-default exclude: '.*\.(css|js|svg)$' @@ -24,7 +24,7 @@ repos: language: system always_run: true pass_filenames: false - entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py openllm-python/src/openllm/serialisation/__init__.pyi + entry: mypy --strict - repo: https://github.com/editorconfig-checker/editorconfig-checker.python rev: '2.7.3' hooks: @@ -53,7 +53,7 @@ repos: openllm-python/tests/models/.* )$ - id: check-yaml - args: ["--unsafe"] + args: ['--unsafe'] - id: check-toml - id: check-docstring-first - id: check-added-large-files diff --git a/README.md b/README.md index d8bd84cd..2a800ae4 100644 --- a/README.md +++ b/README.md @@ -978,9 +978,30 @@ This method is easy to use for one-shot generation use case, but merely served a OpenLLM is not just a standalone product; it's a building block designed to integrate with other powerful tools easily. 
We currently offer integration with [BentoML](https://github.com/bentoml/BentoML),
+[OpenAI's Compatible Endpoints](https://platform.openai.com/docs/api-reference/completions/object),
[LangChain](https://github.com/hwchase17/langchain), and
[Transformers Agents](https://huggingface.co/docs/transformers/transformers_agents).

+### OpenAI Compatible Endpoints
+
+OpenLLM Server can be used as a drop-in replacement for OpenAI's API. Simply
+set `base_url` to `llm-endpoint/v1` and you are good to go:
+
+```python
+import openai
+client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')  # Here the server is running on localhost:3000
+
+model = client.models.list().data[0].id  # query /models for the model currently being served
+completions = client.completions.create(
+  prompt='Write me a tag line for an ice cream shop.', model=model, max_tokens=64, stream=False
+)
+```
+
+The compatible endpoints support `/completions`, `/chat/completions`, and `/models`.
+
+> [!NOTE]
+> You can find OpenAI client examples under the
+> [examples](https://github.com/bentoml/OpenLLM/tree/main/examples) folder.
+
 ### BentoML

 OpenLLM LLM can be integrated as a
diff --git a/mypy.ini b/mypy.ini
index 4d41bddd..60315ebd 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -7,5 +7,4 @@ warn_unused_configs = True
 ignore_missing_imports = true
 check_untyped_defs = true
 warn_unreachable = true
-modules = openllm_client
-files = openllm-client/src/openllm_client/__init__.pyi|openllm-core/src/openllm_core/_typing_compat.py
+files = openllm-client/src/openllm_client/__init__.pyi, openllm-core/src/openllm_core/_typing_compat.py, openllm-client/src/openllm_client/_typing_compat.py, openllm-python/src/openllm/__init__.pyi, openllm-python/src/openllm/client.pyi, openllm-python/src/openllm/serialisation/__init__.pyi
diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py
index 6bf020bf..bee467cd 100644
--- a/openllm-core/src/openllm_core/__init__.py
+++ b/openllm-core/src/openllm_core/__init__.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 from . import exceptions as exceptions
 from . import prompts as prompts
 from . 
import utils as utils @@ -37,4 +35,3 @@ from .config import OPTConfig as OPTConfig from .config import StableLMConfig as StableLMConfig from .config import StarCoderConfig as StarCoderConfig from .prompts import PromptTemplate as PromptTemplate -from .prompts import process_prompt as process_prompt diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index 3cd6c905..0ce9fd94 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -48,7 +48,6 @@ from .utils import first_not_none from .utils import lenient_issubclass from .utils.peft import PEFT_TASK_TYPE_TARGET_MAPPING from .utils.peft import FineTuneConfig -from .utils.peft import PeftType if t.TYPE_CHECKING: @@ -448,31 +447,6 @@ class _ModelSettingsAttr: return _object_getattribute(self, key) raise KeyError(key) - @classmethod - def default(cls) -> _ModelSettingsAttr: - return cls( - **t.cast( - DictStrAny, - ModelSettings( - default_id='__default__', - model_ids=['__default__'], - architecture='PreTrainedModel', - serialisation='legacy', - backend=('pt', 'vllm'), - name_type='dasherize', - url='', - model_type='causal_lm', - trust_remote_code=False, - requirements=None, - conversation=dict(system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''), - add_generation_prompt=False, - timeout=int(36e6), - service_name='', - workers_per_resource=1.0, - ), - ) - ) - # NOTE: The below are dynamically generated by the field_transformer if t.TYPE_CHECKING: # update-config-stubs.py: attrs start @@ -497,56 +471,54 @@ class _ModelSettingsAttr: # update-config-stubs.py: attrs stop -def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _ModelSettingsAttr: - if 'generation_class' in cl_.__config__: - raise ValueError( - f"'generation_class' shouldn't be defined in '__config__', rather defining all required attributes under '{cl_}.GenerationConfig' instead." 
- ) +_DEFAULT = _ModelSettingsAttr( + **ModelSettings( + default_id='__default__', + model_ids=['__default__'], + architecture='PreTrainedModel', + serialisation='legacy', + backend=('pt', 'vllm'), + name_type='dasherize', + url='', + model_type='causal_lm', + trust_remote_code=False, + requirements=None, + conversation=dict(system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''), + add_generation_prompt=False, + timeout=int(36e6), + service_name='', + workers_per_resource=1.0, + ) +) - required_fields = {k for k, ann in t.get_type_hints(ModelSettings).items() if t.get_origin(ann) is Required} - if any(i not in cl_.__config__ for i in required_fields): - raise ValueError(f"Missing required fields {required_fields} '__config__'.") - _cl_name = cl_.__name__.replace('Config', '') - _settings_attr = cls.default() - has_custom_name = all(i in cl_.__config__ for i in {'model_name', 'start_name'}) - _settings_attr = attr.evolve(_settings_attr, **cl_.__config__) - _final_value_dct: DictStrAny = {} + +def structure_settings(cls: type[LLMConfig], _: type[_ModelSettingsAttr]) -> _ModelSettingsAttr: + _cl_name = cls.__name__.replace('Config', '') + has_custom_name = all(i in cls.__config__ for i in {'model_name', 'start_name'}) + _config = attr.evolve(_DEFAULT, **cls.__config__) + _attr = {} if not has_custom_name: - _final_value_dct['model_name'] = ( - inflection.underscore(_cl_name) if _settings_attr['name_type'] == 'dasherize' else _cl_name.lower() - ) - _final_value_dct['start_name'] = ( - inflection.dasherize(_final_value_dct['model_name']) - if _settings_attr['name_type'] == 'dasherize' - else _final_value_dct['model_name'] + _attr['model_name'] = inflection.underscore(_cl_name) if _config['name_type'] == 'dasherize' else _cl_name.lower() + _attr['start_name'] = ( + inflection.dasherize(_attr['model_name']) if _config['name_type'] == 'dasherize' else _attr['model_name'] ) + model_name = _attr['model_name'] if 'model_name' in _attr else _config.model_name - model_name = _final_value_dct['model_name'] if 'model_name' in _final_value_dct else _settings_attr.model_name + _attr.update( + { + 'service_name': f'generated_{model_name}_service.py', + 'conversation': Conversation(name=model_name, **_config.conversation), + 'fine_tune_strategies': { + ft_config.get('adapter_type', 'lora'): FineTuneConfig.from_config(ft_config, cls) + for ft_config in _config.fine_tune_strategies # ft_config is a dict here before transformer + } + if _config.fine_tune_strategies + else {}, + } + ) - _final_value_dct['service_name'] = f'generated_{model_name}_service.py' - - # NOTE: default conversation templates - _conversation = getattr(_settings_attr, 'conversation', None) - if _conversation is None: - _conversation = dict(system_message='', roles=('', ''), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep='') - _final_value_dct['conversation'] = Conversation(name=model_name, **_conversation) - - # NOTE: The key for fine-tune strategies is 'fine_tune_strategies' - _fine_tune_strategies: tuple[dict[str, t.Any], ...] | None = getattr(_settings_attr, 'fine_tune_strategies', None) - _converted: dict[AdapterType, FineTuneConfig] = {} - if _fine_tune_strategies is not None: - # the given value is a tuple[dict[str, t.Any] ,...] 
- for _possible_ft_config in _fine_tune_strategies: - _adapter_type: AdapterType | None = _possible_ft_config.pop('adapter_type', None) - if _adapter_type is None: - raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.") - _llm_config_class = _possible_ft_config.pop('llm_config_class', cl_) - _converted[_adapter_type] = FineTuneConfig( - PeftType[_adapter_type], _possible_ft_config, False, _llm_config_class - ) - _final_value_dct['fine_tune_strategies'] = _converted - return attr.evolve(_settings_attr, **_final_value_dct) + return attr.evolve(_config, **_attr) converter.register_structure_hook(_ModelSettingsAttr, structure_settings) diff --git a/openllm-core/src/openllm_core/utils/lazy.py b/openllm-core/src/openllm_core/utils/lazy.py index 8707678c..3fc617f6 100644 --- a/openllm-core/src/openllm_core/utils/lazy.py +++ b/openllm-core/src/openllm_core/utils/lazy.py @@ -239,6 +239,9 @@ class LazyModule(types.ModuleType): value = self._get_module(name) elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name) + # special branch for openllm_core + elif hasattr(openllm_core, name): + return getattr(openllm_core, name) else: raise AttributeError(f'module {self.__name__} has no attribute {name}') setattr(self, name, value) diff --git a/openllm-core/src/openllm_core/utils/peft.py b/openllm-core/src/openllm_core/utils/peft.py index 5a2301d1..fba211f1 100644 --- a/openllm-core/src/openllm_core/utils/peft.py +++ b/openllm-core/src/openllm_core/utils/peft.py @@ -141,3 +141,16 @@ class FineTuneConfig: inference_mode=inference_mode, adapter_config=config_merger.merge(self.adapter_config, attrs), ) + + @classmethod + def from_config(cls, ft_config: dict[str, t.Any], llm_config_cls: type[LLMConfig]) -> FineTuneConfig: + copied = ft_config.copy() + adapter_type = copied.pop('adapter_type', 'lora') + inference_mode = copied.pop('inference_mode', False) + llm_config_class = copied.pop('llm_confg_class', llm_config_cls) + return cls( + adapter_type=adapter_type, + adapter_config=copied, + inference_mode=inference_mode, + llm_config_class=llm_config_class, + ) diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py index 7ae0ac06..73ce07e1 100644 --- a/openllm-python/src/openllm/__init__.py +++ b/openllm-python/src/openllm/__init__.py @@ -9,44 +9,17 @@ deploy, and monitor any LLMs with ease. 
* Native integration with BentoML and LangChain for custom LLM apps """ -from __future__ import annotations import logging as _logging import os as _os -import typing as _t +import pathlib as _pathlib import warnings as _warnings -from pathlib import Path as _Path - -import openllm_core - -from openllm_core._configuration import GenerationConfig as GenerationConfig -from openllm_core._configuration import LLMConfig as LLMConfig -from openllm_core._configuration import SamplingParams as SamplingParams -from openllm_core._schemas import GenerationInput as GenerationInput -from openllm_core._schemas import GenerationOutput as GenerationOutput -from openllm_core._schemas import MetadataOutput as MetadataOutput -from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING -from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES -from openllm_core.config import AutoConfig as AutoConfig -from openllm_core.config import BaichuanConfig as BaichuanConfig -from openllm_core.config import ChatGLMConfig as ChatGLMConfig -from openllm_core.config import DollyV2Config as DollyV2Config -from openllm_core.config import FalconConfig as FalconConfig -from openllm_core.config import FlanT5Config as FlanT5Config -from openllm_core.config import GPTNeoXConfig as GPTNeoXConfig -from openllm_core.config import LlamaConfig as LlamaConfig -from openllm_core.config import MistralConfig as MistralConfig -from openllm_core.config import MPTConfig as MPTConfig -from openllm_core.config import OPTConfig as OPTConfig -from openllm_core.config import StableLMConfig as StableLMConfig -from openllm_core.config import StarCoderConfig as StarCoderConfig - -from . import exceptions as exceptions from . import utils as utils -if openllm_core.utils.DEBUG: - openllm_core.utils.set_debug_mode(True) - openllm_core.utils.set_quiet_mode(False) + +if utils.DEBUG: + utils.set_debug_mode(True) + utils.set_quiet_mode(False) _logging.basicConfig(level=_logging.NOTSET) else: # configuration for bitsandbytes before import @@ -64,68 +37,36 @@ else: 'ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated' ) -_import_structure: dict[str, list[str]] = { - 'exceptions': [], - 'client': ['HTTPClient', 'AsyncHTTPClient'], - 'bundle': [], - 'playground': [], - 'testing': [], - 'prompts': ['PromptTemplate'], - 'protocol': [], - 'utils': [], - '_deprecated': ['Runner'], - '_strategies': ['CascadingResourceStrategy', 'get_resource'], - 'entrypoints': ['mount_entrypoints'], - 'serialisation': ['ggml', 'transformers'], - 'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'], - '_quantisation': ['infer_quantisation_config'], - '_llm': ['LLM', 'LLMRunner', 'LLMRunnable'], - '_generation': [ - 'StopSequenceCriteria', - 'StopOnTokens', - 'LogitsProcessorList', - 'StoppingCriteriaList', - 'prepare_logits_processor', - ], -} -COMPILED = _Path(__file__).suffix in ('.pyd', '.so') - -if _t.TYPE_CHECKING: - from . import bundle as bundle - from . import cli as cli - from . import client as client - from . import playground as playground - from . import serialisation as serialisation - from . import testing as testing - from . 
import utils as utils - from .client import HTTPClient as HTTPClient - from .client import AsyncHTTPClient as AsyncHTTPClient - from ._deprecated import Runner as Runner - from ._generation import LogitsProcessorList as LogitsProcessorList - from ._generation import StopOnTokens as StopOnTokens - from ._generation import StoppingCriteriaList as StoppingCriteriaList - from ._generation import StopSequenceCriteria as StopSequenceCriteria - from ._generation import prepare_logits_processor as prepare_logits_processor - from ._llm import LLM as LLM - from ._llm import LLMRunnable as LLMRunnable - from ._llm import LLMRunner as LLMRunner - from ._quantisation import infer_quantisation_config as infer_quantisation_config - from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy - from ._strategies import get_resource as get_resource - from .cli._sdk import build as build - from .cli._sdk import import_model as import_model - from .cli._sdk import list_models as list_models - from .cli._sdk import start as start - from .cli._sdk import start_grpc as start_grpc - from .entrypoints import mount_entrypoints as mount_entrypoints - from .prompts import PromptTemplate as PromptTemplate - from .protocol import openai as openai - from .serialisation import ggml as ggml - from .serialisation import transformers as transformers +COMPILED = _pathlib.Path(__file__).suffix in ('.pyd', '.so') # NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__ -__lazy = openllm_core.utils.LazyModule( - __name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED} +__lazy = utils.LazyModule( + __name__, + globals()['__file__'], + { + 'exceptions': [], + 'client': ['HTTPClient', 'AsyncHTTPClient'], + 'bundle': [], + 'playground': [], + 'testing': [], + 'protocol': [], + 'utils': [], + '_deprecated': ['Runner'], + '_strategies': ['CascadingResourceStrategy', 'get_resource'], + 'entrypoints': ['mount_entrypoints'], + 'serialisation': ['ggml', 'transformers'], + 'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'], + '_quantisation': ['infer_quantisation_config'], + '_llm': ['LLM', 'LLMRunner', 'LLMRunnable'], + '_generation': [ + 'StopSequenceCriteria', + 'StopOnTokens', + 'LogitsProcessorList', + 'StoppingCriteriaList', + 'prepare_logits_processor', + ], + }, + extra_objects={'COMPILED': COMPILED}, ) __all__ = __lazy.__all__ __dir__ = __lazy.__dir__ diff --git a/openllm-python/src/openllm/__init__.pyi b/openllm-python/src/openllm/__init__.pyi new file mode 100644 index 00000000..7d444ede --- /dev/null +++ b/openllm-python/src/openllm/__init__.pyi @@ -0,0 +1,54 @@ +from openllm_core._configuration import GenerationConfig as GenerationConfig +from openllm_core._configuration import LLMConfig as LLMConfig +from openllm_core._configuration import SamplingParams as SamplingParams +from openllm_core._schemas import GenerationInput as GenerationInput +from openllm_core._schemas import GenerationOutput as GenerationOutput +from openllm_core._schemas import MetadataOutput as MetadataOutput +from openllm_core.config import CONFIG_MAPPING as CONFIG_MAPPING +from openllm_core.config import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES +from openllm_core.config import AutoConfig as AutoConfig +from openllm_core.config import BaichuanConfig as BaichuanConfig +from openllm_core.config import ChatGLMConfig as ChatGLMConfig +from openllm_core.config import DollyV2Config as DollyV2Config +from openllm_core.config import FalconConfig as 
FalconConfig +from openllm_core.config import FlanT5Config as FlanT5Config +from openllm_core.config import GPTNeoXConfig as GPTNeoXConfig +from openllm_core.config import LlamaConfig as LlamaConfig +from openllm_core.config import MistralConfig as MistralConfig +from openllm_core.config import MPTConfig as MPTConfig +from openllm_core.config import OPTConfig as OPTConfig +from openllm_core.config import StableLMConfig as StableLMConfig +from openllm_core.config import StarCoderConfig as StarCoderConfig +from . import exceptions as exceptions +from . import bundle as bundle +from . import cli as cli +from . import client as client +from . import playground as playground +from . import serialisation as serialisation +from . import testing as testing +from . import utils as utils +from ._deprecated import Runner as Runner +from ._generation import LogitsProcessorList as LogitsProcessorList +from ._generation import StopOnTokens as StopOnTokens +from ._generation import StoppingCriteriaList as StoppingCriteriaList +from ._generation import StopSequenceCriteria as StopSequenceCriteria +from ._generation import prepare_logits_processor as prepare_logits_processor +from ._llm import LLM as LLM +from ._llm import LLMRunnable as LLMRunnable +from ._llm import LLMRunner as LLMRunner +from ._quantisation import infer_quantisation_config as infer_quantisation_config +from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy +from ._strategies import get_resource as get_resource +from .cli._sdk import build as build +from .cli._sdk import import_model as import_model +from .cli._sdk import list_models as list_models +from .cli._sdk import start as start +from .cli._sdk import start_grpc as start_grpc +from .client import AsyncHTTPClient as AsyncHTTPClient +from .client import HTTPClient as HTTPClient +from .entrypoints import mount_entrypoints as mount_entrypoints +from .protocol import openai as openai +from .serialisation import ggml as ggml +from .serialisation import transformers as transformers + +COMPILED: bool = ... diff --git a/openllm-python/src/openllm/_deprecated.py b/openllm-python/src/openllm/_deprecated.py index 7b333dca..104f8b16 100644 --- a/openllm-python/src/openllm/_deprecated.py +++ b/openllm-python/src/openllm/_deprecated.py @@ -1,4 +1,5 @@ from __future__ import annotations +import logging import os import typing as t import warnings @@ -6,39 +7,28 @@ import warnings import openllm from openllm_core._typing_compat import LiteralBackend +from openllm_core._typing_compat import ParamSpec from openllm_core.utils import first_not_none from openllm_core.utils import is_vllm_available -if t.TYPE_CHECKING: - from openllm_core import LLMConfig - from openllm_core._typing_compat import ParamSpec +P = ParamSpec('P') - from ._llm import LLMRunner - - P = ParamSpec('P') - -_object_setattr = object.__setattr__ +logger = logging.getLogger(__name__) -def _mark_deprecated(fn: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: - _object_setattr(fn, '__deprecated__', True) - return fn - - -@_mark_deprecated def Runner( model_name: str, ensure_available: bool = True, init_local: bool = False, backend: LiteralBackend | None = None, - llm_config: LLMConfig | None = None, + llm_config: openllm.LLMConfig | None = None, **attrs: t.Any, -) -> LLMRunner[t.Any, t.Any]: +) -> openllm.LLMRunner[t.Any, t.Any]: """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'. 
> [!WARNING] - > This method is now deprecated and in favor of 'openllm.LLM.runner' + > This method is now deprecated and in favor of 'openllm.LLM' ```python runner = openllm.Runner("dolly-v2") @@ -65,6 +55,10 @@ def Runner( if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name) + if not ensure_available: + logger.warning( + "'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation." + ) model_id = attrs.get('model_id', default=os.getenv('OPENLLM_MODEL_ID', llm_config['default_id'])) _RUNNER_MSG = f"""\ Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax: @@ -99,10 +93,4 @@ def Runner( return llm.runner -_DEPRECATED = {k: v for k, v in locals().items() if getattr(v, '__deprecated__', False)} - -__all__ = list(_DEPRECATED) - - -def __dir__() -> list[str]: - return sorted(_DEPRECATED.keys()) +__all__ = ['Runner'] diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 1b21a1cb..14eaa6ba 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -9,6 +9,8 @@ import typing as t import attr import inflection import orjson +import torch +import transformers from huggingface_hub import hf_hub_download @@ -44,7 +46,6 @@ from openllm_core.utils import first_not_none from openllm_core.utils import flatten_attrs from openllm_core.utils import generate_hash_from_file from openllm_core.utils import is_peft_available -from openllm_core.utils import is_torch_available from openllm_core.utils import resolve_filepath from openllm_core.utils import validate_is_path @@ -57,8 +58,6 @@ from .serialisation.constants import PEFT_CONFIG_NAME if t.TYPE_CHECKING: import peft - import torch - import transformers from bentoml._internal.runner.runnable import RunnableMethod from bentoml._internal.runner.runner import RunnerMethod @@ -68,8 +67,6 @@ if t.TYPE_CHECKING: from openllm_core.utils.representation import ReprArgs else: - transformers = LazyLoader('transformers', globals(), 'transformers') - torch = LazyLoader('torch', globals(), 'torch') peft = LazyLoader('peft', globals(), 'peft') ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['peft.PeftConfig', str]]] @@ -146,12 +143,6 @@ class LLM(t.Generic[M, T], ReprMixin): __llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None __llm_trust_remote_code__: bool = False - device: 'torch.device | None' = None - - def __attrs_post_init__(self) -> None: - if self.__llm_backend__ == 'pt': - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - def __init__( self, model_id: str, @@ -233,10 +224,7 @@ class LLM(t.Generic[M, T], ReprMixin): self._bentomodel = model @apply(lambda val: tuple(str.lower(i) if i else i for i in val)) - def _make_tag_components( - self, model_id: str, model_version: str | None, backend: LiteralBackend - ) -> tuple[str, str | None]: - """Return a valid tag name (---) and its tag version.""" + def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]: model_id, *maybe_revision = model_id.rsplit(':') if len(maybe_revision) > 0: if model_version is not None: @@ -254,16 +242,16 @@ class LLM(t.Generic[M, T], ReprMixin): return f'{backend}-{normalise_model_name(model_id)}', model_version # yapf: disable - def __setattr__(self,attr: str,value: t.Any)->None: + def __setattr__(self,attr,value): if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.') 
super().__setattr__(attr,value) @property def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'} - def __repr_args__(self) -> ReprArgs: - yield 'model_id', self._model_id if not self._local else self.tag.name - yield 'revision', self._revision if self._revision else self.tag.version - yield 'backend', self.__llm_backend__ - yield 'type', self.llm_type + def __repr_args__(self): + yield 'model_id',self._model_id if not self._local else self.tag.name + yield 'revision',self._revision if self._revision else self.tag.version + yield 'backend',self.__llm_backend__ + yield 'type',self.llm_type @property def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'} @property @@ -277,11 +265,7 @@ class LLM(t.Generic[M, T], ReprMixin): @property def tag(self)->bentoml.Tag:return self._tag @property - def bentomodel(self)->bentoml.Model:return self._bentomodel - @property - def config(self)->LLMConfig: - if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs) - return self.__llm_config__ + def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self) @property def quantization_config(self)->transformers.BitsAndBytesConfig|transformers.GPTQConfig|transformers.AwqConfig: if self.__llm_quantization_config__ is None: @@ -300,31 +284,53 @@ class LLM(t.Generic[M, T], ReprMixin): def identifying_params(self)->DictStrAny:return {'configuration': self.config.model_dump_json().decode(),'model_ids': orjson.dumps(self.config['model_ids']).decode(),'model_id': self.model_id} @property def llm_parameters(self)->tuple[tuple[tuple[t.Any,...],DictStrAny],DictStrAny]:return (self._model_decls,self._model_attrs),self._tokenizer_attrs - # yapf: enable - - # NOTE: The section helps with fine-tuning. + # NOTE: This section is the actual model, tokenizer, and config reference here. + @property + def config(self)->LLMConfig: + if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs) + return self.__llm_config__ + @property + def tokenizer(self)->T: + if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self,**self.llm_parameters[-1]) + return self.__llm_tokenizer__ + @property + def runner(self)->LLMRunner[M, T]: + if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self) + return self.__llm_runner__ + @property + def model(self)->M: + if self.__llm_model__ is None: + model=openllm.serialisation.load_model(self,*self._model_decls,**self._model_attrs) + # If OOM, then it is probably you don't have enough VRAM to run this model. + if self.__llm_backend__ == 'pt': + loaded_in_kbit = getattr(model,'is_loaded_in_8bit',False) or getattr(model,'is_loaded_in_4bit',False) or getattr(model,'is_quantized',False) + if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit: + try: model = model.to('cuda') + except Exception as err: raise OpenLLMException(f'Failed to load model into GPU: {err}\n. 
See https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.') from err + if self.has_adapters: + logger.debug('Applying the following adapters: %s', self.adapter_map) + for adapter_dict in self.adapter_map.values(): + for adapter_name, (peft_config, peft_model_id) in adapter_dict.items(): + model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config) + self.__llm_model__ = model + return self.__llm_model__ @property def adapter_map(self) -> ResolvedAdapterMap: try: import peft as _ # noqa: F401 except ImportError as err: - raise MissingDependencyError( - "Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'" - ) from err - if not self.has_adapters: - raise AttributeError('Adapter map is not available.') + raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err + if not self.has_adapters: raise AttributeError('Adapter map is not available.') assert self._adapter_map is not None if self.__llm_adapter_map__ is None: _map: ResolvedAdapterMap = {k: {} for k in self._adapter_map} for adapter_type, adapter_tuple in self._adapter_map.items(): - base = first_not_none( - self.config['fine_tune_strategies'].get(adapter_type), - default=self.config.make_fine_tune_config(adapter_type), - ) + base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type)) for adapter in adapter_tuple: _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id) self.__llm_adapter_map__ = _map return self.__llm_adapter_map__ + # yapf: enable def prepare_for_training( self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any @@ -343,56 +349,11 @@ class LLM(t.Generic[M, T], ReprMixin): raise ValueError('Adapter should not be specified when fine-tuning.') model = get_peft_model( prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config - ) # type: ignore[no-untyped-call] + ) if DEBUG: model.print_trainable_parameters() # type: ignore[no-untyped-call] return model, self.tokenizer - @property - def model(self) -> M: - if self.__llm_model__ is None: - model = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs) - # If OOM, then it is probably you don't have enough VRAM to run this model. - if self.__llm_backend__ == 'pt': - if is_torch_available(): - loaded_in_kbit = ( - getattr(model, 'is_loaded_in_8bit', False) - or getattr(model, 'is_loaded_in_4bit', False) - or getattr(model, 'is_quantized', False) - ) - if ( - torch.cuda.is_available() - and torch.cuda.device_count() == 1 - and not loaded_in_kbit - and not isinstance(model, transformers.Pipeline) - ): - try: - model = model.to('cuda') - except Exception as err: - raise OpenLLMException( - f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.' 
- ) from err - if self.has_adapters: - logger.debug('Applying the following adapters: %s', self.adapter_map) - for adapter_dict in self.adapter_map.values(): - for adapter_name, (peft_config, peft_model_id) in adapter_dict.items(): - model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config) - self.__llm_model__ = model - return self.__llm_model__ - - @property - def tokenizer(self) -> T: - # NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer - if self.__llm_tokenizer__ is None: - self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1]) - return self.__llm_tokenizer__ - - @property - def runner(self) -> LLMRunner[M, T]: - if self.__llm_runner__ is None: - self.__llm_runner__ = _RunnerFactory(self) - return self.__llm_runner__ - async def generate( self, prompt: str | None, diff --git a/openllm-python/src/openllm/cli/entrypoint.py b/openllm-python/src/openllm/cli/entrypoint.py index 79e9b219..ae08090e 100644 --- a/openllm-python/src/openllm/cli/entrypoint.py +++ b/openllm-python/src/openllm/cli/entrypoint.py @@ -23,7 +23,6 @@ bentomodel = openllm.import_model('mistralai/Mistral-7B-v0.1') from __future__ import annotations import enum import functools -import importlib.util import inspect import itertools import logging @@ -31,6 +30,7 @@ import os import platform import random import subprocess +import threading import time import traceback import typing as t @@ -65,6 +65,7 @@ from openllm_core._typing_compat import LiteralString from openllm_core._typing_compat import NotRequired from openllm_core._typing_compat import ParamSpec from openllm_core._typing_compat import Self +from openllm_core._typing_compat import TypeGuard from openllm_core.config import CONFIG_MAPPING from openllm_core.utils import DEBUG_ENV_VAR from openllm_core.utils import OPTIONAL_DEPENDENCIES @@ -451,84 +452,31 @@ def start_command( # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream. 
development = server_attrs.pop('development') server_attrs.setdefault('production', not development) - wpr = first_not_none(workers_per_resource, default=config['workers_per_resource']) - if isinstance(wpr, str): - if wpr == 'round_robin': - wpr = 1.0 - elif wpr == 'conserved': - if device and openllm.utils.device_count() == 0: - termui.echo('--device will have no effect as there is no GPUs available', fg='yellow') - wpr = 1.0 - else: - available_gpu = len(device) if device else openllm.utils.device_count() - wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu) - else: - wpr = float(wpr) - elif isinstance(wpr, int): - wpr = float(wpr) - requirements = llm.config['requirements'] - if requirements is not None and len(requirements) > 0: - missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None] - if len(missing_requirements) > 0: - termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow') - - start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy()) - start_env.update( - { - 'OPENLLM_MODEL_ID': model_id, - 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), - 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()), - 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(), - 'OPENLLM_SERIALIZATION': serialisation, - 'OPENLLM_BACKEND': llm.__llm_backend__, - 'OPENLLM_CONFIG': llm.config.model_dump_json(flatten=True).decode(), - } + start_env = process_environ( + config, + server_timeout, + process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device), + device, + cors, + model_id, + adapter_map, + serialisation, + llm, + system_message, + prompt_template, ) - if llm._quantise: - start_env['OPENLLM_QUANTIZE'] = str(llm._quantise) - if system_message: - start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message - if prompt_template: - start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template server = bentoml.HTTPServer('_service:svc', **server_attrs) - openllm.utils.analytics.track_start_init(llm.config) + openllm.utils.analytics.track_start_init(config) - def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None: - if caught_exception: - return - cmd_name = f'openllm build {model_id}' - if llm._quantise: - cmd_name += f' --quantize {llm._quantise}' - cmd_name += f' --serialization {serialisation}' - if adapter_map is not None: - cmd_name += ' ' + ' '.join( - [ - f'--adapter-id {s}' - for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()] - ] - ) - if not openllm.utils.get_quiet_mode(): - termui.info(f"\n\nπŸš€ Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'") - - _exception = False - if return_process: - server.start(env=start_env, text=True) - if server.process is None: - raise click.ClickException('Failed to start the server.') - return server.process - else: - try: - server.start(env=start_env, text=True, blocking=True) - except KeyboardInterrupt: - _exception = True - except Exception as err: - termui.error(f'Error caught while running LLM Server:\n{err}') - _exception = True - raise - else: - next_step(adapter_map, _exception) + try: + build_bento_instruction(llm, model_id, serialisation, adapter_map) + it = run_server(server.args, start_env, return_process=return_process) + if return_process: + return it + except KeyboardInterrupt: + pass # NOTE: Return the configuration for 
telemetry purposes. return config @@ -550,7 +498,9 @@ def start_command( help='Deprecated. Use positional argument instead.', ) @start_decorator(serve_grpc=True) +@click.pass_context def start_grpc_command( + ctx: click.Context, model_id: str, server_timeout: int, model_version: str | None, @@ -626,7 +576,61 @@ def start_grpc_command( # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream. development = server_attrs.pop('development') server_attrs.setdefault('production', not development) - wpr = first_not_none(workers_per_resource, default=config['workers_per_resource']) + + start_env = process_environ( + config, + server_timeout, + process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device), + device, + cors, + model_id, + adapter_map, + serialisation, + llm, + system_message, + prompt_template, + ) + + server = bentoml.GrpcServer('_service:svc', **server_attrs) + openllm.utils.analytics.track_start_init(llm.config) + + try: + build_bento_instruction(llm, model_id, serialisation, adapter_map) + it = run_server(server.args, start_env, return_process=return_process) + if return_process: + return it + except KeyboardInterrupt: + pass + + # NOTE: Return the configuration for telemetry purposes. + return config + + +def process_environ( + config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, system_message, prompt_template +) -> t.Dict[str, t.Any]: + environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy()) + environ.update( + { + 'OPENLLM_MODEL_ID': model_id, + 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), + 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()), + 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(), + 'OPENLLM_SERIALIZATION': serialisation, + 'OPENLLM_BACKEND': llm.__llm_backend__, + 'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(), + } + ) + if llm._quantise: + environ['OPENLLM_QUANTIZE'] = str(llm._quantise) + if system_message: + environ['OPENLLM_SYSTEM_MESSAGE'] = system_message + if prompt_template: + environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template + return environ + + +def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]: if isinstance(wpr, str): if wpr == 'round_robin': wpr = 1.0 @@ -641,76 +645,82 @@ def start_grpc_command( wpr = float(wpr) elif isinstance(wpr, int): wpr = float(wpr) + return wpr - requirements = llm.config['requirements'] - if requirements is not None and len(requirements) > 0: - missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None] - if len(missing_requirements) > 0: - termui.warning(f'Make sure to have the following dependencies available: {missing_requirements}') - start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy()) - start_env.update( - { - 'OPENLLM_MODEL_ID': model_id, - 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), - 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()), - 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(), - 'OPENLLM_SERIALIZATION': serialisation, - 'OPENLLM_BACKEND': llm.__llm_backend__, - 'OPENLLM_CONFIG': llm.config.model_dump_json().decode(), - } - ) +def build_bento_instruction(llm, model_id, serialisation, adapter_map): + cmd_name = f'openllm build {model_id}' if llm._quantise: - start_env['OPENLLM_QUANTIZE'] = 
str(llm._quantise) - if system_message: - start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message - if prompt_template: - start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template + cmd_name += f' --quantize {llm._quantise}' + cmd_name += f' --serialization {serialisation}' + if adapter_map is not None: + cmd_name += ' ' + ' '.join( + [ + f'--adapter-id {s}' + for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()] + ] + ) + if not openllm.utils.get_quiet_mode(): + termui.info(f"πŸš€Tip: run '{cmd_name}' to create a BentoLLM for '{model_id}'") - server = bentoml.GrpcServer('_service:svc', **server_attrs) - openllm.utils.analytics.track_start_init(llm.config) - def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None: - if caught_exception: - return - cmd_name = f'openllm build {model_id}' - if llm._quantise: - cmd_name += f' --quantize {llm._quantise}' - cmd_name += f' --serialization {serialisation}' - if adapter_map is not None: - cmd_name += ' ' + ' '.join( - [ - f'--adapter-id {s}' - for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()] - ] - ) - if not openllm.utils.get_quiet_mode(): - termui.info(f"\nπŸš€ Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'") - - _exception = False - if return_process: - server.start(env=start_env, text=True) - if server.process is None: - raise click.ClickException('Failed to start the server.') - return server.process +def pretty_print(line: str): + if 'WARNING' in line: + caller = termui.warning + elif 'INFO' in line: + caller = termui.info + elif 'DEBUG' in line: + caller = termui.debug + elif 'ERROR' in line: + caller = termui.error else: - try: - server.start(env=start_env, text=True, blocking=True) - except KeyboardInterrupt: - _exception = True - except Exception as err: - termui.error(f'Error caught while running LLM Server:\n{err}') - _exception = True - raise - else: - next_step(adapter_map, _exception) + caller = caller = functools.partial(termui.echo, fg=None) + caller(line.strip()) - # NOTE: Return the configuration for telemetry purposes. 
- return config + +def handle(stream, stop_event): + try: + for line in iter(stream.readline, ''): + if stop_event.is_set(): + break + pretty_print(line) + finally: + stream.close() + + +def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int: + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, text=True) + if return_process: + return process + stop_event = threading.Event() + # yapf: disable + stdout, stderr = threading.Thread(target=handle, args=(process.stdout, stop_event)), threading.Thread(target=handle, args=(process.stderr, stop_event)) + stdout.start(); stderr.start() + + try: + process.wait() + except KeyboardInterrupt: + stop_event.set() + process.terminate() + try: + process.wait(0.1) + except subprocess.TimeoutExpired: + # not sure if the process exits cleanly + process.kill() + raise + finally: + stop_event.set() + stdout.join(); stderr.join() + if process.poll() is not None: process.kill() + stdout.join(); stderr.join() + # yapf: disable + + return process.returncode class ItemState(enum.Enum): NOT_FOUND = 'NOT_FOUND' + ADDED = 'ADDED' EXISTS = 'EXISTS' OVERWRITE = 'OVERWRITE' @@ -808,6 +818,7 @@ def import_command( model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code) if llm.__llm_backend__ == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() + state = ItemState.ADDED response = ImportModelOutput(state=state, backend=llm.__llm_backend__, tag=str(model.tag)) termui.echo(orjson.dumps(response).decode(), fg='white') return response @@ -1052,6 +1063,8 @@ def build_command( container_registry=container_registry, container_version_strategy=container_version_strategy, ) + if state != ItemState.OVERWRITE: + state = ItemState.ADDED except Exception as err: traceback.print_exc() raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err @@ -1074,10 +1087,10 @@ def build_command( if machine: termui.echo(f'__object__:{orjson.dumps(response).decode()}\n\n', fg='white') elif not get_quiet_mode() and (not push or not containerize): - if not overwrite: - termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n") - elif state != ItemState.EXISTS: + if state != ItemState.EXISTS: termui.info(f"Successfully built Bento '{bento.tag}'.\n") + elif not overwrite: + termui.warning(f"Bento for '{model_id}' already exists [{bento}]. 
To overwrite it pass '--overwrite'.\n") if not get_debug_mode(): termui.echo(OPENLLM_FIGLET) termui.echo('\nπŸ“– Next steps:\n\n', nl=False) diff --git a/openllm-python/src/openllm/cli/termui.py b/openllm-python/src/openllm/cli/termui.py index 66ce3081..23ea3d9c 100644 --- a/openllm-python/src/openllm/cli/termui.py +++ b/openllm-python/src/openllm/cli/termui.py @@ -36,6 +36,16 @@ class Level(enum.IntEnum): Level.CRITICAL: 'red', }[self] + @classmethod + def from_logging_level(cls, level: int) -> Level: + return { + logging.DEBUG: Level.DEBUG, + logging.INFO: Level.INFO, + logging.WARNING: Level.WARNING, + logging.ERROR: Level.ERROR, + logging.CRITICAL: Level.CRITICAL, + }[level] + class JsonLog(t.TypedDict): log_level: Level @@ -43,13 +53,7 @@ class JsonLog(t.TypedDict): def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None: - def caller(text: str) -> None: - if get_debug_mode(): - logger.log(level.value, text) - else: - echo(JsonLog(log_level=level, content=content), json=True, fg=fg) - - caller(orjson.dumps(JsonLog(log_level=level, content=content)).decode()) + echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True) warning = functools.partial(log, level=Level.WARNING) @@ -61,18 +65,17 @@ notset = functools.partial(log, level=Level.NOTSET) def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None: - if json and not isinstance(text, dict): - raise TypeError('text must be a dict') if json: + text = orjson.loads(text) if 'content' in text and 'log_level' in text: - content = t.cast(DictStrAny, text)['content'] - fg = t.cast(Level, text['log_level']).color + content = text['content'] + fg = Level.from_logging_level(text['log_level']).color else: content = orjson.dumps(text).decode() fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color else: content = t.cast(str, text) - attrs['fg'] = fg if not get_debug_mode() else None + attrs['fg'] = fg if not get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs) diff --git a/openllm-python/src/openllm/client.py b/openllm-python/src/openllm/client.py index 1ededf04..91a95db4 100644 --- a/openllm-python/src/openllm/client.py +++ b/openllm-python/src/openllm/client.py @@ -1,27 +1,9 @@ -"""OpenLLM Python client. - -```python -client = openllm.client.HTTPClient("http://localhost:8080") -client.query("What is the difference between gather and scatter?") -``` -""" - -from __future__ import annotations -import typing as t - -import openllm_client +import openllm_client as _client -if t.TYPE_CHECKING: - from openllm_client import AsyncHTTPClient as AsyncHTTPClient - from openllm_client import HTTPClient as HTTPClient - # from openllm_client import AsyncGrpcClient as AsyncGrpcClient - # from openllm_client import GrpcClient as GrpcClient +def __dir__(): + return sorted(dir(_client)) -def __dir__() -> t.Sequence[str]: - return sorted(dir(openllm_client)) - - -def __getattr__(it: str) -> t.Any: - return getattr(openllm_client, it) +def __getattr__(it): + return getattr(_client, it) diff --git a/openllm-python/src/openllm/client.pyi b/openllm-python/src/openllm/client.pyi new file mode 100644 index 00000000..f70f192a --- /dev/null +++ b/openllm-python/src/openllm/client.pyi @@ -0,0 +1,10 @@ +"""OpenLLM Python client. 
+ +```python +client = openllm.client.HTTPClient("http://localhost:8080") +client.query("What is the difference between gather and scatter?") +``` +""" + +from openllm_client import HTTPClient as HTTPClient +from openllm_client import AsyncHTTPClient as AsyncHTTPClient diff --git a/openllm-python/src/openllm/exceptions.py b/openllm-python/src/openllm/exceptions.py index 5c2d94eb..e9798770 100644 --- a/openllm-python/src/openllm/exceptions.py +++ b/openllm-python/src/openllm/exceptions.py @@ -1,7 +1,3 @@ -"""Base exceptions for OpenLLM. This extends BentoML exceptions.""" - -from __future__ import annotations - from openllm_core.exceptions import Error as Error from openllm_core.exceptions import FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError from openllm_core.exceptions import ForbiddenAttributeError as ForbiddenAttributeError diff --git a/openllm-python/src/openllm/prompts.py b/openllm-python/src/openllm/prompts.py deleted file mode 100644 index 35987e98..00000000 --- a/openllm-python/src/openllm/prompts.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import annotations - -from openllm_core.prompts import PromptTemplate as PromptTemplate diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py index ddb07982..58bc5aa2 100644 --- a/openllm-python/tests/conftest.py +++ b/openllm-python/tests/conftest.py @@ -28,12 +28,7 @@ def parametrise_local_llm( pytest.skip(f"'{model}' is not yet supported in framework testing.") backends: tuple[LiteralBackend, ...] = ('pt',) for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()): - yield ( - prompt, - openllm.Runner( - model, model_id=_MODELING_MAPPING[model], ensure_available=True, backend=backend, init_local=True - ), - ) + yield prompt, openllm.LLM(_MODELING_MAPPING[model], backend=backend) def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: diff --git a/openllm-python/tests/models_test.py b/openllm-python/tests/models_test.py index 80d526fe..57e07034 100644 --- a/openllm-python/tests/models_test.py +++ b/openllm-python/tests/models_test.py @@ -11,20 +11,20 @@ if t.TYPE_CHECKING: @pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) + assert llm.generate(prompt) - assert llm(prompt, temperature=0.8, top_p=0.23) + assert llm.generate(prompt, temperature=0.8, top_p=0.23) @pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) + assert llm.generate(prompt) - assert llm(prompt, temperature=0.9, top_k=8) + assert llm.generate(prompt, temperature=0.9, top_k=8) @pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm(prompt) + assert llm.generate(prompt) - assert llm(prompt, temperature=0.95) + assert llm.generate(prompt, temperature=0.95)
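For reference, the conftest and model-test changes above now construct `openllm.LLM` directly and call its `generate` coroutine instead of going through the deprecated `openllm.Runner`. A minimal sketch of that post-refactor usage, assuming an illustrative model id (any supported model id works) and that the weights can be fetched; the sampling parameters simply mirror the ones used in the tests:

```python
import asyncio

import openllm

# Construct the LLM directly; this replaces the deprecated openllm.Runner(...) factory.
# 'facebook/opt-125m' is only an example model id, not a requirement of the API.
llm = openllm.LLM('facebook/opt-125m', backend='pt')


async def main() -> None:
  # generate() is defined as a coroutine in _llm.py, so it is awaited here.
  output = await llm.generate('What is the meaning of life?', temperature=0.9, top_k=8)
  print(output)


asyncio.run(main())
```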