From 52a44b1bfa2c7f3c3ea6ae9e3e2ee79092355fd0 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Wed, 22 Nov 2023 21:51:51 -0500 Subject: [PATCH] chore: cleanup loader (#729) Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- openllm-client/src/openllm_client/_utils.pyi | 1 + .../src/openllm_core/utils/__init__.py | 184 ++++-------------- .../src/openllm_core/utils/__init__.pyi | 1 + .../src/openllm_core/utils/import_utils.py | 70 ++----- openllm-python/src/openllm/_runners.py | 80 ++++---- .../serialisation/transformers/__init__.py | 44 +++-- openllm-python/src/openllm/utils.pyi | 3 +- openllm-python/src/openllm_cli/entrypoint.py | 6 +- 8 files changed, 125 insertions(+), 264 deletions(-) diff --git a/openllm-client/src/openllm_client/_utils.pyi b/openllm-client/src/openllm_client/_utils.pyi index 90738487..51cfaad9 100644 --- a/openllm-client/src/openllm_client/_utils.pyi +++ b/openllm-client/src/openllm_client/_utils.pyi @@ -41,6 +41,7 @@ from openllm_core.utils.import_utils import ( is_grpc_available as is_grpc_available, is_jupyter_available as is_jupyter_available, is_jupytext_available as is_jupytext_available, + is_flash_attn_2_available as is_flash_attn_2_available, is_notebook_available as is_notebook_available, is_peft_available as is_peft_available, is_torch_available as is_torch_available, diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py index 42c7e3aa..f1db9f89 100644 --- a/openllm-core/src/openllm_core/utils/__init__.py +++ b/openllm-core/src/openllm_core/utils/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import functools, hashlib, logging, logging.config, os, sys, types, uuid, typing as t from pathlib import Path as _Path - from . import import_utils as iutils, pkg from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES from .lazy import LazyLoader as LazyLoader, LazyModule as LazyModule, VersionInfo as VersionInfo @@ -16,107 +15,60 @@ DEV_DEBUG_VAR = 'DEBUG' # equivocal setattr to save one lookup per assignment _object_setattr = object.__setattr__ logger = logging.getLogger(__name__) - - @functools.lru_cache(maxsize=1) def _WithArgsTypes() -> tuple[type[t.Any], ...]: try: - from typing import GenericAlias as _TypingGenericAlias # type: ignore + from typing import GenericAlias as _TypingGenericAlias except ImportError: - _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on) + _TypingGenericAlias = () # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) # _GenericAlias is the actual GenericAlias implementation - return ( - (_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType) - ) # type: ignore - - + return (_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType) def lenient_issubclass(cls, class_or_tuple): try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) except TypeError: - if isinstance(cls, _WithArgsTypes()): - return False + if isinstance(cls, _WithArgsTypes()): return False raise - - def resolve_user_filepath(filepath, ctx): _path = os.path.expanduser(os.path.expandvars(filepath)) - if os.path.exists(_path): - return os.path.realpath(_path) + if os.path.exists(_path): return os.path.realpath(_path) # Try finding file in ctx if provided if ctx: _path = os.path.expanduser(os.path.join(ctx, filepath)) - if os.path.exists(_path): - return os.path.realpath(_path) + if os.path.exists(_path): return os.path.realpath(_path) raise FileNotFoundError(f'file {filepath} not found') - - # this is the supress version of resolve_user_filepath def resolve_filepath(path, ctx=None): try: return resolve_user_filepath(path, ctx) except FileNotFoundError: return path - - def check_bool_env(env, default=True): v = os.getenv(env, default=str(default)).upper() - if v.isdigit(): - return bool(int(v)) # special check for digits + if v.isdigit(): return bool(int(v)) # special check for digits return v in ENV_VARS_TRUE_VALUES - - -def calc_dir_size(path): - return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file()) - - +def calc_dir_size(path): return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file()) @functools.lru_cache(maxsize=128) -def generate_hash_from_file(f, algorithm='sha1'): - return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()) - - +def generate_hash_from_file(f, algorithm='sha1'): return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()) def getenv(env, default=None, var=None): env_key = {env.upper(), f'OPENLLM_{env.upper()}'} - if var is not None: - env_key = set(var) | env_key - + if var is not None: env_key = set(var) | env_key def callback(k: str) -> t.Any: _var = os.getenv(k) - if _var and k.startswith('OPENLLM_'): - logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper()) + if _var and k.startswith('OPENLLM_'): logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper()) return _var - return first_not_none(*(callback(k) for k in env_key), default=default) - - -def field_env_key(key, suffix=None): - return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key]))) - - -def get_debug_mode(): - return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG - - +def field_env_key(key, suffix=None): return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key]))) +def get_debug_mode(): return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG def get_quiet_mode(): - if QUIET_ENV_VAR in os.environ: - return check_bool_env(QUIET_ENV_VAR, False) - if DEBUG: - return False + if QUIET_ENV_VAR in os.environ: return check_bool_env(QUIET_ENV_VAR, False) + if DEBUG: return False return False - - -def get_disable_warnings(): - return 
check_bool_env(WARNING_ENV_VAR, False) - - +def get_disable_warnings(): return check_bool_env(WARNING_ENV_VAR, False) def set_disable_warnings(disable=True): - if disable: - os.environ[WARNING_ENV_VAR] = str(disable) - - + if disable: os.environ[WARNING_ENV_VAR] = str(disable) def set_debug_mode(enabled, level=1): - if enabled: - os.environ[DEV_DEBUG_VAR] = str(level) + if enabled: os.environ[DEV_DEBUG_VAR] = str(level) os.environ.update( { DEBUG_ENV_VAR: str(enabled), @@ -126,132 +78,79 @@ def set_debug_mode(enabled, level=1): } ) set_disable_warnings(not enabled) - - def set_quiet_mode(enabled): os.environ.update( {QUIET_ENV_VAR: str(enabled), DEBUG_ENV_VAR: str(not enabled), _GRPC_DEBUG_ENV_VAR: 'NONE', 'CT2_VERBOSE': '-1'} ) set_disable_warnings(enabled) - - -def gen_random_uuid(prefix: str | None = None) -> str: - return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)]) - - +def gen_random_uuid(prefix: str | None = None) -> str: return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)]) # NOTE: `compose` any number of unary functions into a single unary function # compose(f, g, h)(x) == f(g(h(x))); compose(f, g, h)(x, y, z) == f(g(h(x, y, z))) -def compose(*funcs): - return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs) - - +def compose(*funcs): return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs) # NOTE: `apply` a transform function that is invoked on results returned from the decorated function # apply(reversed)(func)(*args, **kwargs) == reversed(func(*args, **kwargs)) -def apply(transform): - return lambda func: functools.wraps(func)(compose(transform, func)) - - -def validate_is_path(maybe_path): - return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) - - -def first_not_none(*args, default=None): - return next((arg for arg in args if arg is not None), default) - - +def apply(transform): return lambda func: functools.wraps(func)(compose(transform, func)) +def validate_is_path(maybe_path): return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) +def first_not_none(*args, default=None): return next((arg for arg in args if arg is not None), default) def generate_context(framework_name): from bentoml._internal.models.model import ModelContext - framework_versions = { 'transformers': pkg.get_pkg_version('transformers'), 'safetensors': pkg.get_pkg_version('safetensors'), 'optimum': pkg.get_pkg_version('optimum'), 'accelerate': pkg.get_pkg_version('accelerate'), } - if iutils.is_torch_available(): - framework_versions['torch'] = pkg.get_pkg_version('torch') - if iutils.is_ctranslate_available(): - framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2') - if iutils.is_vllm_available(): - framework_versions['vllm'] = pkg.get_pkg_version('vllm') - if iutils.is_autoawq_available(): - framework_versions['autoawq'] = pkg.get_pkg_version('autoawq') - if iutils.is_autogptq_available(): - framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq') - if iutils.is_bentoml_available(): - framework_versions['bentoml'] = pkg.get_pkg_version('bentoml') - if iutils.is_triton_available(): - framework_versions['triton'] = pkg.get_pkg_version('triton') + if iutils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch') + if iutils.is_ctranslate_available(): framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2') + if iutils.is_vllm_available(): framework_versions['vllm'] = pkg.get_pkg_version('vllm') + if 
iutils.is_autoawq_available(): framework_versions['autoawq'] = pkg.get_pkg_version('autoawq') + if iutils.is_autogptq_available(): framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq') + if iutils.is_bentoml_available(): framework_versions['bentoml'] = pkg.get_pkg_version('bentoml') + if iutils.is_triton_available(): framework_versions['triton'] = pkg.get_pkg_version('triton') + if iutils.is_flash_attn_2_available(): framework_versions['flash_attn'] = pkg.get_pkg_version('flash_attn') return ModelContext(framework_name=framework_name, framework_versions=framework_versions) - - @functools.lru_cache(maxsize=1) def in_notebook(): try: - from IPython.core.getipython import get_ipython - - return 'IPKernelApp' in get_ipython().config + from IPython.core.getipython import get_ipython; return 'IPKernelApp' in get_ipython().config except Exception: return False - - -_TOKENIZER_PREFIX = '_tokenizer_' - - def flatten_attrs(**attrs): + _TOKENIZER_PREFIX = '_tokenizer_' tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)} for k in tuple(attrs.keys()): if k.startswith(_TOKENIZER_PREFIX): del attrs[k] return attrs, tokenizer_attrs - - # Special debug flag controled via DEBUG DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False)) # Whether to show the codenge for debug purposes -SHOW_CODEGEN = DEBUG and ( - os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3 -) +SHOW_CODEGEN = DEBUG and os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3 # MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins MYPY = False - - class ExceptionFilter(logging.Filter): def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any): - if exclude_exceptions is None: - exclude_exceptions = [] + if exclude_exceptions is None: exclude_exceptions = [] try: from circus.exc import ConflictError - - if ConflictError not in exclude_exceptions: - exclude_exceptions.append(ConflictError) + if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError) except ImportError: pass super(ExceptionFilter, self).__init__(**kwargs) self.EXCLUDE_EXCEPTIONS = exclude_exceptions - def filter(self, record: logging.LogRecord) -> bool: if record.exc_info: etype, _, _ = record.exc_info if etype is not None: for exc in self.EXCLUDE_EXCEPTIONS: - if issubclass(etype, exc): - return False + if issubclass(etype, exc): return False return True - - class InfoFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - return logging.INFO <= record.levelno < logging.WARNING - - + def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING class WarningFilter(logging.Filter): # FIXME: Why does this not work? 
def filter(self, record: logging.LogRecord) -> bool: - if get_disable_warnings(): - return record.levelno >= logging.ERROR + if get_disable_warnings(): return record.levelno >= logging.ERROR return True - - _LOGGING_CONFIG = { 'version': 1, 'disable_existing_loggers': True, @@ -274,8 +173,6 @@ _LOGGING_CONFIG = { }, 'root': {'level': logging.WARNING}, } - - def configure_logging(): if get_quiet_mode(): _LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR @@ -295,8 +192,6 @@ def configure_logging(): _LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR logging.config.dictConfig(_LOGGING_CONFIG) - - # XXX: define all classes, functions import above this line # since _extras will be the locals() import from this file. _extras = { @@ -332,6 +227,7 @@ __lazy = LazyModule( 'is_ctranslate_available', 'is_transformers_available', 'is_autoawq_available', + 'is_flash_attn_2_available', 'is_bentoml_available', 'is_triton_available', ], diff --git a/openllm-core/src/openllm_core/utils/__init__.pyi b/openllm-core/src/openllm_core/utils/__init__.pyi index 752c68ee..787ee940 100644 --- a/openllm-core/src/openllm_core/utils/__init__.pyi +++ b/openllm-core/src/openllm_core/utils/__init__.pyi @@ -16,6 +16,7 @@ from openllm_core.utils.import_utils import ( is_jupytext_available as is_jupytext_available, is_notebook_available as is_notebook_available, is_peft_available as is_peft_available, + is_flash_attn_2_available as is_flash_attn_2_available, is_torch_available as is_torch_available, is_transformers_available as is_transformers_available, is_vllm_available as is_vllm_available, diff --git a/openllm-core/src/openllm_core/utils/import_utils.py b/openllm-core/src/openllm_core/utils/import_utils.py index 55eb79d5..c6cd37fc 100644 --- a/openllm-core/src/openllm_core/utils/import_utils.py +++ b/openllm-core/src/openllm_core/utils/import_utils.py @@ -15,8 +15,6 @@ OPTIONAL_DEPENDENCIES = { ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'}) USE_VLLM = os.getenv('USE_VLLM', 'AUTO').upper() - - def _is_package_available(package: str) -> bool: _package_available = importlib.util.find_spec(package) is not None if _package_available: @@ -25,8 +23,6 @@ def _is_package_available(package: str) -> bool: except importlib.metadata.PackageNotFoundError: _package_available = False return _package_available - - _ctranslate_available = importlib.util.find_spec('ctranslate2') is not None _vllm_available = importlib.util.find_spec('vllm') is not None _grpc_available = importlib.util.find_spec('grpc') is not None @@ -37,60 +33,24 @@ _transformers_available = _is_package_available('transformers') _bentoml_available = _is_package_available('bentoml') _peft_available = _is_package_available('peft') _bitsandbytes_available = _is_package_available('bitsandbytes') +_flash_attn_available = _is_package_available('flash_attn') _jupyter_available = _is_package_available('jupyter') _jupytext_available = _is_package_available('jupytext') _notebook_available = _is_package_available('notebook') _autogptq_available = _is_package_available('auto_gptq') - - -def is_triton_available() -> bool: - return _triton_available - - -def is_ctranslate_available() -> bool: - return _ctranslate_available - - -def is_bentoml_available() -> bool: - return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml - - -def is_transformers_available() -> bool: - return _transformers_available # needs this since openllm-core doesn't explicitly 
depends on transformers - - -def is_grpc_available() -> bool: - return _grpc_available - - -def is_jupyter_available() -> bool: - return _jupyter_available - - -def is_jupytext_available() -> bool: - return _jupytext_available - - -def is_notebook_available() -> bool: - return _notebook_available - - -def is_peft_available() -> bool: - return _peft_available - - -def is_bitsandbytes_available() -> bool: - return _bitsandbytes_available - - -def is_autogptq_available() -> bool: - return _autogptq_available - - -def is_torch_available() -> bool: - return _torch_available - - +def is_triton_available() -> bool: return _triton_available +def is_ctranslate_available() -> bool: return _ctranslate_available +def is_bentoml_available() -> bool: return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml +def is_transformers_available() -> bool: return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers +def is_grpc_available() -> bool: return _grpc_available +def is_jupyter_available() -> bool: return _jupyter_available +def is_jupytext_available() -> bool: return _jupytext_available +def is_notebook_available() -> bool: return _notebook_available +def is_peft_available() -> bool: return _peft_available +def is_bitsandbytes_available() -> bool: return _bitsandbytes_available +def is_autogptq_available() -> bool: return _autogptq_available +def is_torch_available() -> bool: return _torch_available +def is_flash_attn_2_available() -> bool: return _flash_attn_available def is_autoawq_available() -> bool: global _autoawq_available try: @@ -98,8 +58,6 @@ def is_autoawq_available() -> bool: except importlib.metadata.PackageNotFoundError: _autoawq_available = False return _autoawq_available - - def is_vllm_available() -> bool: global _vllm_available if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES or _vllm_available: diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index 16c4f6bf..2dfb4ee7 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -83,8 +83,7 @@ class CTranslateRunnable(bentoml.Runnable): SUPPORTS_CPU_MULTI_THREADING = True def __init__(self, llm): - if not is_ctranslate_available(): - raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`') + if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`') self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer @bentoml.Runnable.method(batchable=False) @@ -127,8 +126,7 @@ class vLLMRunnable(bentoml.Runnable): SUPPORTS_CPU_MULTI_THREADING = True def __init__(self, llm): - if not is_vllm_available(): - raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.') + if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. 
Do `pip install "openllm[vllm]"`.') import vllm self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer @@ -180,13 +178,8 @@ class PyTorchRunnable(bentoml.Runnable): @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): - if adapter_name is not None: - self.model.set_adapter(adapter_name) - async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs): - yield generation_output.model_dump_json() - - async def forward(self, prompt_token_ids, request_id, stop, **attrs): - from ._generation import get_context_length, is_partial_stop, prepare_logits_processor + from ._generation import get_context_length, prepare_logits_processor + if adapter_name is not None: self.model.set_adapter(adapter_name) max_new_tokens = attrs.pop('max_new_tokens', 256) context_length = attrs.pop('context_length', None) @@ -224,6 +217,8 @@ class PyTorchRunnable(bentoml.Runnable): finish_reason = None prompt_logprobs = [] prompt_token_indices = [] + stopped = False + sample_logprobs: SampleLogprobs = [None] # The first token has no logprobs for i in range(config['max_new_tokens']): if i == 0: # prefill @@ -247,10 +242,9 @@ class PyTorchRunnable(bentoml.Runnable): ) logits = out.logits past_key_values = out.past_key_values - if logits_processor: if config['repetition_penalty'] > 1.0: - tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device) + tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device) else: tmp_output_ids = None last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] @@ -272,11 +266,11 @@ class PyTorchRunnable(bentoml.Runnable): token = tokens[0] output_token_ids.append(token) - # NOTE: We can't use last_token_logits since logprobs is based on raw logits - logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float) - sample_logprobs: SampleLogprobs = [] - token_logprobs = logprobs[token].item() - cumulative_logprob += token_logprobs + if config['logprobs']: + # NOTE: We can't use last_token_logits since logprobs is based on raw logits + logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float) + token_logprobs = logprobs[token].item() + cumulative_logprob += token_logprobs if config['prompt_logprobs']: for token_id in prompt_token_ids: @@ -296,40 +290,32 @@ class PyTorchRunnable(bentoml.Runnable): clean_up_tokenization_spaces=True, ) - partially_stopped = False if len(stop) > 0: for it in stop: pos = text.rfind(it, rfind_start) if pos != -1: text, stopped = text[:pos], True break - else: - partially_stopped = is_partial_stop(text, it) - if partially_stopped: - break - if config['logprobs']: - sample_logprobs.append({token: token_logprobs}) + if config['logprobs']: sample_logprobs.append({token: token_logprobs}) - if not partially_stopped: - # TODO: calculate prompt_logprobs - yield GenerationOutput( - prompt='', - finished=False, - outputs=[ - CompletionChunk( - index=0, - text=text, - token_ids=tmp_output_ids, - cumulative_logprob=cumulative_logprob, - logprobs=sample_logprobs if config['logprobs'] else None, - finish_reason=None, - ) - ], - prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, - request_id=request_id, - ) + yield GenerationOutput( + prompt='', + finished=False, + outputs=[ + CompletionChunk( + index=0, + text=text, + token_ids=tmp_output_ids, + cumulative_logprob=cumulative_logprob, + logprobs=sample_logprobs if config['logprobs'] else 
None, + finish_reason=None, + ) + ], + prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None, + request_id=request_id, + ).model_dump_json() if stopped: break else: @@ -343,16 +329,16 @@ class PyTorchRunnable(bentoml.Runnable): CompletionChunk( index=0, text=text, - token_ids=output_token_ids[input_len:], + token_ids=output_token_ids, cumulative_logprob=cumulative_logprob, logprobs=sample_logprobs if config['logprobs'] else None, finish_reason=finish_reason, ) ], prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, + prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None, request_id=request_id, - ) + ).model_dump_json() # Clean del past_key_values, out diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py index 8a26e6e8..f2b40cfb 100644 --- a/openllm-python/src/openllm/serialisation/transformers/__init__.py +++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py @@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm from huggingface_hub import snapshot_download from openllm_core.exceptions import OpenLLMException -from openllm_core.utils import first_not_none, is_autogptq_available +from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config from .weights import HfIgnore @@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs): if llm.config['model_type'] != 'causal_lm': raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") - # TODO: investigate load with flash attention - model = auto_class.from_pretrained( - llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs - ) + try: + model = auto_class.from_pretrained( + llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs + ) + except Exception as err: + logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err) + model = auto_class.from_pretrained( + llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs + ) else: - model = auto_class.from_pretrained( - llm.bentomodel.path, - *decls, - config=config, - trust_remote_code=llm.trust_remote_code, - device_map=device_map, - **attrs, - ) + try: + model = auto_class.from_pretrained( + llm.bentomodel.path, + *decls, + config=config, + trust_remote_code=llm.trust_remote_code, + device_map=device_map, + use_flash_attention_2=is_flash_attn_2_available(), + **attrs, + ) + except Exception as err: + logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err) + model = auto_class.from_pretrained( + llm.bentomodel.path, + *decls, + config=config, + trust_remote_code=llm.trust_remote_code, + device_map=device_map, + **attrs, + ) + check_unintialised_params(model) return model diff --git a/openllm-python/src/openllm/utils.pyi b/openllm-python/src/openllm/utils.pyi index 13062d91..25dc267e 100644 --- a/openllm-python/src/openllm/utils.pyi +++ b/openllm-python/src/openllm/utils.pyi @@ -35,6 +35,7 @@ from openllm_core.utils import ( is_bentoml_available as is_bentoml_available, is_bitsandbytes_available as is_bitsandbytes_available, is_ctranslate_available as is_ctranslate_available, + 
is_flash_attn_2_available as is_flash_attn_2_available, is_grpc_available as is_grpc_available, is_jupyter_available as is_jupyter_available, is_jupytext_available as is_jupytext_available, @@ -55,7 +56,7 @@ from openllm_core.utils import ( ) from openllm_core.utils.serde import converter as converter -from .._llm import LLM +from ._llm import LLM def available_devices() -> Tuple[str, ...]: ... def device_count() -> int: ... diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 8dabd075..6801dde3 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -1237,7 +1237,7 @@ def prune_command( for b in bentoml.bentos.list() if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name) ] - if model_name is None: + else: available += [ (b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels ] @@ -1342,11 +1342,11 @@ def query_command( if stream: stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **_memoized) - termui.echo(prompt, fg=input_fg, nl=False if stream else True) + termui.echo(prompt, fg=input_fg, nl=False) for it in stream_res: termui.echo(it.text, fg=generated_fg, nl=False) else: - termui.echo(prompt, fg=input_fg, nl=False) + termui.echo(prompt, fg=input_fg, nl=True) termui.echo(client.generate(prompt, **_memoized).outputs[0].text, fg=generated_fg, nl=False) ctx.exit(0)
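A note on the serialisation change: load_model() now passes use_flash_attention_2=is_flash_attn_2_available() to from_pretrained and falls back to a plain load (logging the error at debug level) when that raises. A minimal sketch of the same pattern outside of OpenLLM, assuming a transformers release that still accepts the use_flash_attention_2 keyword (later versions supersede it with attn_implementation='flash_attention_2') and a placeholder model id:

from __future__ import annotations
import importlib.util, logging

import torch, transformers

logger = logging.getLogger(__name__)

def flash_attn_2_installed() -> bool:
  # simplified stand-in for the new is_flash_attn_2_available() helper,
  # which checks that the flash_attn package is importable
  return importlib.util.find_spec('flash_attn') is not None

def load_causal_lm(model_id: str, **attrs):
  kwargs = dict(device_map='auto', torch_dtype=torch.float16, **attrs)
  try:
    return transformers.AutoModelForCausalLM.from_pretrained(
      model_id, use_flash_attention_2=flash_attn_2_installed(), **kwargs
    )
  except Exception as err:
    # same fallback as the patch: log at debug level, retry without the flag
    logger.debug("Failed to load model with 'use_flash_attention_2':\n%s", err)
    return transformers.AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

# model = load_causal_lm('org/model-id')  # placeholder id, not from the patch

The broad except clause is what keeps loading working on models or transformers versions that reject the flag.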
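The compose and apply helpers collapsed into one-liners above keep the semantics spelled out in their comments: compose(f, g, h)(x) == f(g(h(x))), and apply(transform) post-processes whatever the decorated function returns. A short usage sketch; split_csv is only an illustrative example, not part of the codebase:

import functools

def compose(*funcs):
  return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)

def apply(transform):
  return lambda func: functools.wraps(func)(compose(transform, func))

assert compose(str.upper, str.strip)('  hi  ') == 'HI'  # strip runs first, then upper

@apply(tuple)  # freeze the returned list into a tuple
def split_csv(line):
  return [part.strip() for part in line.split(',')]

assert split_csv('a, b, c') == ('a', 'b', 'c')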
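The PyTorch runner change is behavioural as well as cosmetic: generate_iterator now yields GenerationOutput(...).model_dump_json() strings for every chunk (including the final one), and logprobs/prompt_logprobs are only populated when config['logprobs'] / config['prompt_logprobs'] request them. A sketch of what a consumer of that stream looks like, assuming the dict keys match the GenerationOutput/CompletionChunk fields constructed in the hunk above and that each chunk carries the text decoded so far rather than a delta; the stream argument is a placeholder for however the runner method is invoked:

from __future__ import annotations
import typing as t

import orjson

def read_chunk(raw):
  out = orjson.loads(raw)          # one yield == one serialised GenerationOutput
  completion = out['outputs'][0]   # a single CompletionChunk per yield
  return completion['text'], completion['finish_reason']

async def collect(stream: t.AsyncIterator[str]) -> str:
  text, finish_reason = '', None
  async for raw in stream:
    text, finish_reason = read_chunk(raw)
    if finish_reason is not None:  # only the final chunk sets a finish_reason
      break
  return text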