mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-19 05:57:39 -04:00
chore: cleanup loader (#729)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -41,6 +41,7 @@ from openllm_core.utils.import_utils import (
|
||||
is_grpc_available as is_grpc_available,
|
||||
is_jupyter_available as is_jupyter_available,
|
||||
is_jupytext_available as is_jupytext_available,
|
||||
is_flash_attn_2_available as is_flash_attn_2_available,
|
||||
is_notebook_available as is_notebook_available,
|
||||
is_peft_available as is_peft_available,
|
||||
is_torch_available as is_torch_available,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import functools, hashlib, logging, logging.config, os, sys, types, uuid, typing as t
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from . import import_utils as iutils, pkg
|
||||
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
|
||||
from .lazy import LazyLoader as LazyLoader, LazyModule as LazyModule, VersionInfo as VersionInfo
|
||||
@@ -16,107 +15,60 @@ DEV_DEBUG_VAR = 'DEBUG'
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _WithArgsTypes() -> tuple[type[t.Any], ...]:
|
||||
try:
|
||||
from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
from typing import GenericAlias as _TypingGenericAlias
|
||||
except ImportError:
|
||||
_TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
_TypingGenericAlias = () # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
# _GenericAlias is the actual GenericAlias implementation
|
||||
return (
|
||||
(_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
return (_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType)
|
||||
def lenient_issubclass(cls, class_or_tuple):
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes()):
|
||||
return False
|
||||
if isinstance(cls, _WithArgsTypes()): return False
|
||||
raise
|
||||
|
||||
|
||||
def resolve_user_filepath(filepath, ctx):
|
||||
_path = os.path.expanduser(os.path.expandvars(filepath))
|
||||
if os.path.exists(_path):
|
||||
return os.path.realpath(_path)
|
||||
if os.path.exists(_path): return os.path.realpath(_path)
|
||||
# Try finding file in ctx if provided
|
||||
if ctx:
|
||||
_path = os.path.expanduser(os.path.join(ctx, filepath))
|
||||
if os.path.exists(_path):
|
||||
return os.path.realpath(_path)
|
||||
if os.path.exists(_path): return os.path.realpath(_path)
|
||||
raise FileNotFoundError(f'file {filepath} not found')
|
||||
|
||||
|
||||
# this is the supress version of resolve_user_filepath
|
||||
def resolve_filepath(path, ctx=None):
|
||||
try:
|
||||
return resolve_user_filepath(path, ctx)
|
||||
except FileNotFoundError:
|
||||
return path
|
||||
|
||||
|
||||
def check_bool_env(env, default=True):
|
||||
v = os.getenv(env, default=str(default)).upper()
|
||||
if v.isdigit():
|
||||
return bool(int(v)) # special check for digits
|
||||
if v.isdigit(): return bool(int(v)) # special check for digits
|
||||
return v in ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
def calc_dir_size(path):
|
||||
return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
|
||||
|
||||
|
||||
def calc_dir_size(path): return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f, algorithm='sha1'):
|
||||
return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
|
||||
|
||||
|
||||
def generate_hash_from_file(f, algorithm='sha1'): return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
|
||||
def getenv(env, default=None, var=None):
|
||||
env_key = {env.upper(), f'OPENLLM_{env.upper()}'}
|
||||
if var is not None:
|
||||
env_key = set(var) | env_key
|
||||
|
||||
if var is not None: env_key = set(var) | env_key
|
||||
def callback(k: str) -> t.Any:
|
||||
_var = os.getenv(k)
|
||||
if _var and k.startswith('OPENLLM_'):
|
||||
logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
|
||||
if _var and k.startswith('OPENLLM_'): logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
|
||||
return _var
|
||||
|
||||
return first_not_none(*(callback(k) for k in env_key), default=default)
|
||||
|
||||
|
||||
def field_env_key(key, suffix=None):
|
||||
return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
|
||||
|
||||
|
||||
def get_debug_mode():
|
||||
return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
|
||||
|
||||
|
||||
def field_env_key(key, suffix=None): return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
|
||||
def get_debug_mode(): return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
|
||||
def get_quiet_mode():
|
||||
if QUIET_ENV_VAR in os.environ:
|
||||
return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG:
|
||||
return False
|
||||
if QUIET_ENV_VAR in os.environ: return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG: return False
|
||||
return False
|
||||
|
||||
|
||||
def get_disable_warnings():
|
||||
return check_bool_env(WARNING_ENV_VAR, False)
|
||||
|
||||
|
||||
def get_disable_warnings(): return check_bool_env(WARNING_ENV_VAR, False)
|
||||
def set_disable_warnings(disable=True):
|
||||
if disable:
|
||||
os.environ[WARNING_ENV_VAR] = str(disable)
|
||||
|
||||
|
||||
if disable: os.environ[WARNING_ENV_VAR] = str(disable)
|
||||
def set_debug_mode(enabled, level=1):
|
||||
if enabled:
|
||||
os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ.update(
|
||||
{
|
||||
DEBUG_ENV_VAR: str(enabled),
|
||||
@@ -126,132 +78,79 @@ def set_debug_mode(enabled, level=1):
|
||||
}
|
||||
)
|
||||
set_disable_warnings(not enabled)
|
||||
|
||||
|
||||
def set_quiet_mode(enabled):
|
||||
os.environ.update(
|
||||
{QUIET_ENV_VAR: str(enabled), DEBUG_ENV_VAR: str(not enabled), _GRPC_DEBUG_ENV_VAR: 'NONE', 'CT2_VERBOSE': '-1'}
|
||||
)
|
||||
set_disable_warnings(enabled)
|
||||
|
||||
|
||||
def gen_random_uuid(prefix: str | None = None) -> str:
|
||||
return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
|
||||
|
||||
|
||||
def gen_random_uuid(prefix: str | None = None) -> str: return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
|
||||
# NOTE: `compose` any number of unary functions into a single unary function
|
||||
# compose(f, g, h)(x) == f(g(h(x))); compose(f, g, h)(x, y, z) == f(g(h(x, y, z)))
|
||||
def compose(*funcs):
|
||||
return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)
|
||||
|
||||
|
||||
def compose(*funcs): return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)
|
||||
# NOTE: `apply` a transform function that is invoked on results returned from the decorated function
|
||||
# apply(reversed)(func)(*args, **kwargs) == reversed(func(*args, **kwargs))
|
||||
def apply(transform):
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
|
||||
|
||||
def validate_is_path(maybe_path):
|
||||
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
|
||||
def first_not_none(*args, default=None):
|
||||
return next((arg for arg in args if arg is not None), default)
|
||||
|
||||
|
||||
def apply(transform): return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
def validate_is_path(maybe_path): return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
def first_not_none(*args, default=None): return next((arg for arg in args if arg is not None), default)
|
||||
def generate_context(framework_name):
|
||||
from bentoml._internal.models.model import ModelContext
|
||||
|
||||
framework_versions = {
|
||||
'transformers': pkg.get_pkg_version('transformers'),
|
||||
'safetensors': pkg.get_pkg_version('safetensors'),
|
||||
'optimum': pkg.get_pkg_version('optimum'),
|
||||
'accelerate': pkg.get_pkg_version('accelerate'),
|
||||
}
|
||||
if iutils.is_torch_available():
|
||||
framework_versions['torch'] = pkg.get_pkg_version('torch')
|
||||
if iutils.is_ctranslate_available():
|
||||
framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2')
|
||||
if iutils.is_vllm_available():
|
||||
framework_versions['vllm'] = pkg.get_pkg_version('vllm')
|
||||
if iutils.is_autoawq_available():
|
||||
framework_versions['autoawq'] = pkg.get_pkg_version('autoawq')
|
||||
if iutils.is_autogptq_available():
|
||||
framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq')
|
||||
if iutils.is_bentoml_available():
|
||||
framework_versions['bentoml'] = pkg.get_pkg_version('bentoml')
|
||||
if iutils.is_triton_available():
|
||||
framework_versions['triton'] = pkg.get_pkg_version('triton')
|
||||
if iutils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch')
|
||||
if iutils.is_ctranslate_available(): framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2')
|
||||
if iutils.is_vllm_available(): framework_versions['vllm'] = pkg.get_pkg_version('vllm')
|
||||
if iutils.is_autoawq_available(): framework_versions['autoawq'] = pkg.get_pkg_version('autoawq')
|
||||
if iutils.is_autogptq_available(): framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq')
|
||||
if iutils.is_bentoml_available(): framework_versions['bentoml'] = pkg.get_pkg_version('bentoml')
|
||||
if iutils.is_triton_available(): framework_versions['triton'] = pkg.get_pkg_version('triton')
|
||||
if iutils.is_flash_attn_2_available(): framework_versions['flash_attn'] = pkg.get_pkg_version('flash_attn')
|
||||
return ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
|
||||
return 'IPKernelApp' in get_ipython().config
|
||||
from IPython.core.getipython import get_ipython; return 'IPKernelApp' in get_ipython().config
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_TOKENIZER_PREFIX = '_tokenizer_'
|
||||
|
||||
|
||||
def flatten_attrs(**attrs):
|
||||
_TOKENIZER_PREFIX = '_tokenizer_'
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX):
|
||||
del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
|
||||
# Special debug flag controled via DEBUG
|
||||
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False))
|
||||
# Whether to show the codenge for debug purposes
|
||||
SHOW_CODEGEN = DEBUG and (
|
||||
os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
|
||||
)
|
||||
SHOW_CODEGEN = DEBUG and os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
|
||||
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
|
||||
MYPY = False
|
||||
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
if exclude_exceptions is None:
|
||||
exclude_exceptions = []
|
||||
if exclude_exceptions is None: exclude_exceptions = []
|
||||
try:
|
||||
from circus.exc import ConflictError
|
||||
|
||||
if ConflictError not in exclude_exceptions:
|
||||
exclude_exceptions.append(ConflictError)
|
||||
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
|
||||
except ImportError:
|
||||
pass
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc):
|
||||
return False
|
||||
if issubclass(etype, exc): return False
|
||||
return True
|
||||
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING
|
||||
class WarningFilter(logging.Filter): # FIXME: Why does this not work?
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if get_disable_warnings():
|
||||
return record.levelno >= logging.ERROR
|
||||
if get_disable_warnings(): return record.levelno >= logging.ERROR
|
||||
return True
|
||||
|
||||
|
||||
_LOGGING_CONFIG = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': True,
|
||||
@@ -274,8 +173,6 @@ _LOGGING_CONFIG = {
|
||||
},
|
||||
'root': {'level': logging.WARNING},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging():
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
|
||||
@@ -295,8 +192,6 @@ def configure_logging():
|
||||
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
|
||||
|
||||
logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras = {
|
||||
@@ -332,6 +227,7 @@ __lazy = LazyModule(
|
||||
'is_ctranslate_available',
|
||||
'is_transformers_available',
|
||||
'is_autoawq_available',
|
||||
'is_flash_attn_2_available',
|
||||
'is_bentoml_available',
|
||||
'is_triton_available',
|
||||
],
|
||||
|
||||
@@ -16,6 +16,7 @@ from openllm_core.utils.import_utils import (
|
||||
is_jupytext_available as is_jupytext_available,
|
||||
is_notebook_available as is_notebook_available,
|
||||
is_peft_available as is_peft_available,
|
||||
is_flash_attn_2_available as is_flash_attn_2_available,
|
||||
is_torch_available as is_torch_available,
|
||||
is_transformers_available as is_transformers_available,
|
||||
is_vllm_available as is_vllm_available,
|
||||
|
||||
@@ -15,8 +15,6 @@ OPTIONAL_DEPENDENCIES = {
|
||||
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
|
||||
USE_VLLM = os.getenv('USE_VLLM', 'AUTO').upper()
|
||||
|
||||
|
||||
def _is_package_available(package: str) -> bool:
|
||||
_package_available = importlib.util.find_spec(package) is not None
|
||||
if _package_available:
|
||||
@@ -25,8 +23,6 @@ def _is_package_available(package: str) -> bool:
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_package_available = False
|
||||
return _package_available
|
||||
|
||||
|
||||
_ctranslate_available = importlib.util.find_spec('ctranslate2') is not None
|
||||
_vllm_available = importlib.util.find_spec('vllm') is not None
|
||||
_grpc_available = importlib.util.find_spec('grpc') is not None
|
||||
@@ -37,60 +33,24 @@ _transformers_available = _is_package_available('transformers')
|
||||
_bentoml_available = _is_package_available('bentoml')
|
||||
_peft_available = _is_package_available('peft')
|
||||
_bitsandbytes_available = _is_package_available('bitsandbytes')
|
||||
_flash_attn_available = _is_package_available('flash_attn')
|
||||
_jupyter_available = _is_package_available('jupyter')
|
||||
_jupytext_available = _is_package_available('jupytext')
|
||||
_notebook_available = _is_package_available('notebook')
|
||||
_autogptq_available = _is_package_available('auto_gptq')
|
||||
|
||||
|
||||
def is_triton_available() -> bool:
|
||||
return _triton_available
|
||||
|
||||
|
||||
def is_ctranslate_available() -> bool:
|
||||
return _ctranslate_available
|
||||
|
||||
|
||||
def is_bentoml_available() -> bool:
|
||||
return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
|
||||
|
||||
|
||||
def is_transformers_available() -> bool:
|
||||
return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
|
||||
|
||||
|
||||
def is_grpc_available() -> bool:
|
||||
return _grpc_available
|
||||
|
||||
|
||||
def is_jupyter_available() -> bool:
|
||||
return _jupyter_available
|
||||
|
||||
|
||||
def is_jupytext_available() -> bool:
|
||||
return _jupytext_available
|
||||
|
||||
|
||||
def is_notebook_available() -> bool:
|
||||
return _notebook_available
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return _peft_available
|
||||
|
||||
|
||||
def is_bitsandbytes_available() -> bool:
|
||||
return _bitsandbytes_available
|
||||
|
||||
|
||||
def is_autogptq_available() -> bool:
|
||||
return _autogptq_available
|
||||
|
||||
|
||||
def is_torch_available() -> bool:
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_triton_available() -> bool: return _triton_available
|
||||
def is_ctranslate_available() -> bool: return _ctranslate_available
|
||||
def is_bentoml_available() -> bool: return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
|
||||
def is_transformers_available() -> bool: return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
|
||||
def is_grpc_available() -> bool: return _grpc_available
|
||||
def is_jupyter_available() -> bool: return _jupyter_available
|
||||
def is_jupytext_available() -> bool: return _jupytext_available
|
||||
def is_notebook_available() -> bool: return _notebook_available
|
||||
def is_peft_available() -> bool: return _peft_available
|
||||
def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
|
||||
def is_autogptq_available() -> bool: return _autogptq_available
|
||||
def is_torch_available() -> bool: return _torch_available
|
||||
def is_flash_attn_2_available() -> bool: return _flash_attn_available
|
||||
def is_autoawq_available() -> bool:
|
||||
global _autoawq_available
|
||||
try:
|
||||
@@ -98,8 +58,6 @@ def is_autoawq_available() -> bool:
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_autoawq_available = False
|
||||
return _autoawq_available
|
||||
|
||||
|
||||
def is_vllm_available() -> bool:
|
||||
global _vllm_available
|
||||
if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES or _vllm_available:
|
||||
|
||||
@@ -83,8 +83,7 @@ class CTranslateRunnable(bentoml.Runnable):
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available():
|
||||
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
@@ -127,8 +126,7 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_vllm_available():
|
||||
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
import vllm
|
||||
|
||||
self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
|
||||
@@ -180,13 +178,8 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
if adapter_name is not None:
|
||||
self.model.set_adapter(adapter_name)
|
||||
async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
|
||||
yield generation_output.model_dump_json()
|
||||
|
||||
async def forward(self, prompt_token_ids, request_id, stop, **attrs):
|
||||
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
|
||||
from ._generation import get_context_length, prepare_logits_processor
|
||||
if adapter_name is not None: self.model.set_adapter(adapter_name)
|
||||
|
||||
max_new_tokens = attrs.pop('max_new_tokens', 256)
|
||||
context_length = attrs.pop('context_length', None)
|
||||
@@ -224,6 +217,8 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
finish_reason = None
|
||||
prompt_logprobs = []
|
||||
prompt_token_indices = []
|
||||
stopped = False
|
||||
sample_logprobs: SampleLogprobs = [None] # The first token has no logprobs
|
||||
|
||||
for i in range(config['max_new_tokens']):
|
||||
if i == 0: # prefill
|
||||
@@ -247,10 +242,9 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
if logits_processor:
|
||||
if config['repetition_penalty'] > 1.0:
|
||||
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
|
||||
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
|
||||
else:
|
||||
tmp_output_ids = None
|
||||
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
|
||||
@@ -272,11 +266,11 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
token = tokens[0]
|
||||
output_token_ids.append(token)
|
||||
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
|
||||
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
|
||||
sample_logprobs: SampleLogprobs = []
|
||||
token_logprobs = logprobs[token].item()
|
||||
cumulative_logprob += token_logprobs
|
||||
if config['logprobs']:
|
||||
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
|
||||
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
|
||||
token_logprobs = logprobs[token].item()
|
||||
cumulative_logprob += token_logprobs
|
||||
|
||||
if config['prompt_logprobs']:
|
||||
for token_id in prompt_token_ids:
|
||||
@@ -296,40 +290,32 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
partially_stopped = False
|
||||
if len(stop) > 0:
|
||||
for it in stop:
|
||||
pos = text.rfind(it, rfind_start)
|
||||
if pos != -1:
|
||||
text, stopped = text[:pos], True
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(text, it)
|
||||
if partially_stopped:
|
||||
break
|
||||
|
||||
if config['logprobs']:
|
||||
sample_logprobs.append({token: token_logprobs})
|
||||
if config['logprobs']: sample_logprobs.append({token: token_logprobs})
|
||||
|
||||
if not partially_stopped:
|
||||
# TODO: calculate prompt_logprobs
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=False,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=tmp_output_ids,
|
||||
cumulative_logprob=cumulative_logprob,
|
||||
logprobs=sample_logprobs if config['logprobs'] else None,
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
request_id=request_id,
|
||||
)
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=False,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=tmp_output_ids,
|
||||
cumulative_logprob=cumulative_logprob,
|
||||
logprobs=sample_logprobs if config['logprobs'] else None,
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
|
||||
request_id=request_id,
|
||||
).model_dump_json()
|
||||
if stopped:
|
||||
break
|
||||
else:
|
||||
@@ -343,16 +329,16 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=output_token_ids[input_len:],
|
||||
token_ids=output_token_ids,
|
||||
cumulative_logprob=cumulative_logprob,
|
||||
logprobs=sample_logprobs if config['logprobs'] else None,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
|
||||
request_id=request_id,
|
||||
)
|
||||
).model_dump_json()
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
|
||||
@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import first_not_none, is_autogptq_available
|
||||
from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
|
||||
|
||||
from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
|
||||
from .weights import HfIgnore
|
||||
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
|
||||
# TODO: investigate load with flash attention
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
else:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
use_flash_attention_2=is_flash_attn_2_available(),
|
||||
**attrs,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
|
||||
check_unintialised_params(model)
|
||||
return model
|
||||
|
||||
@@ -35,6 +35,7 @@ from openllm_core.utils import (
|
||||
is_bentoml_available as is_bentoml_available,
|
||||
is_bitsandbytes_available as is_bitsandbytes_available,
|
||||
is_ctranslate_available as is_ctranslate_available,
|
||||
is_flash_attn_2_available as is_flash_attn_2_available,
|
||||
is_grpc_available as is_grpc_available,
|
||||
is_jupyter_available as is_jupyter_available,
|
||||
is_jupytext_available as is_jupytext_available,
|
||||
@@ -55,7 +56,7 @@ from openllm_core.utils import (
|
||||
)
|
||||
from openllm_core.utils.serde import converter as converter
|
||||
|
||||
from .._llm import LLM
|
||||
from ._llm import LLM
|
||||
|
||||
def available_devices() -> Tuple[str, ...]: ...
|
||||
def device_count() -> int: ...
|
||||
|
||||
@@ -1237,7 +1237,7 @@ def prune_command(
|
||||
for b in bentoml.bentos.list()
|
||||
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
|
||||
]
|
||||
if model_name is None:
|
||||
else:
|
||||
available += [
|
||||
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
|
||||
]
|
||||
@@ -1342,11 +1342,11 @@ def query_command(
|
||||
|
||||
if stream:
|
||||
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **_memoized)
|
||||
termui.echo(prompt, fg=input_fg, nl=False if stream else True)
|
||||
termui.echo(prompt, fg=input_fg, nl=False)
|
||||
for it in stream_res:
|
||||
termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else:
|
||||
termui.echo(prompt, fg=input_fg, nl=False)
|
||||
termui.echo(prompt, fg=input_fg, nl=True)
|
||||
termui.echo(client.generate(prompt, **_memoized).outputs[0].text, fg=generated_fg, nl=False)
|
||||
ctx.exit(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user