chore: cleanup loader (#729)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-22 21:51:51 -05:00
committed by GitHub
parent 5442d9cd10
commit 52a44b1bfa
8 changed files with 125 additions and 264 deletions

View File

@@ -41,6 +41,7 @@ from openllm_core.utils.import_utils import (
is_grpc_available as is_grpc_available,
is_jupyter_available as is_jupyter_available,
is_jupytext_available as is_jupytext_available,
is_flash_attn_2_available as is_flash_attn_2_available,
is_notebook_available as is_notebook_available,
is_peft_available as is_peft_available,
is_torch_available as is_torch_available,

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import functools, hashlib, logging, logging.config, os, sys, types, uuid, typing as t
from pathlib import Path as _Path
from . import import_utils as iutils, pkg
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
from .lazy import LazyLoader as LazyLoader, LazyModule as LazyModule, VersionInfo as VersionInfo
@@ -16,107 +15,60 @@ DEV_DEBUG_VAR = 'DEBUG'
# equivocal setattr to save one lookup per assignment
_object_setattr = object.__setattr__
logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=1)
def _WithArgsTypes() -> tuple[type[t.Any], ...]:
try:
from typing import GenericAlias as _TypingGenericAlias # type: ignore
from typing import GenericAlias as _TypingGenericAlias
except ImportError:
_TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
_TypingGenericAlias = () # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
# _GenericAlias is the actual GenericAlias implementation
return (
(_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType)
) # type: ignore
return (_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType)
def lenient_issubclass(cls, class_or_tuple):
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if isinstance(cls, _WithArgsTypes()):
return False
if isinstance(cls, _WithArgsTypes()): return False
raise
def resolve_user_filepath(filepath, ctx):
_path = os.path.expanduser(os.path.expandvars(filepath))
if os.path.exists(_path):
return os.path.realpath(_path)
if os.path.exists(_path): return os.path.realpath(_path)
# Try finding file in ctx if provided
if ctx:
_path = os.path.expanduser(os.path.join(ctx, filepath))
if os.path.exists(_path):
return os.path.realpath(_path)
if os.path.exists(_path): return os.path.realpath(_path)
raise FileNotFoundError(f'file {filepath} not found')
# this is the supress version of resolve_user_filepath
def resolve_filepath(path, ctx=None):
try:
return resolve_user_filepath(path, ctx)
except FileNotFoundError:
return path
def check_bool_env(env, default=True):
v = os.getenv(env, default=str(default)).upper()
if v.isdigit():
return bool(int(v)) # special check for digits
if v.isdigit(): return bool(int(v)) # special check for digits
return v in ENV_VARS_TRUE_VALUES
def calc_dir_size(path):
return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
def calc_dir_size(path): return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file())
@functools.lru_cache(maxsize=128)
def generate_hash_from_file(f, algorithm='sha1'):
return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
def generate_hash_from_file(f, algorithm='sha1'): return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest())
def getenv(env, default=None, var=None):
env_key = {env.upper(), f'OPENLLM_{env.upper()}'}
if var is not None:
env_key = set(var) | env_key
if var is not None: env_key = set(var) | env_key
def callback(k: str) -> t.Any:
_var = os.getenv(k)
if _var and k.startswith('OPENLLM_'):
logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
if _var and k.startswith('OPENLLM_'): logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper())
return _var
return first_not_none(*(callback(k) for k in env_key), default=default)
def field_env_key(key, suffix=None):
return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
def get_debug_mode():
return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
def field_env_key(key, suffix=None): return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
def get_debug_mode(): return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG
def get_quiet_mode():
if QUIET_ENV_VAR in os.environ:
return check_bool_env(QUIET_ENV_VAR, False)
if DEBUG:
return False
if QUIET_ENV_VAR in os.environ: return check_bool_env(QUIET_ENV_VAR, False)
if DEBUG: return False
return False
def get_disable_warnings():
return check_bool_env(WARNING_ENV_VAR, False)
def get_disable_warnings(): return check_bool_env(WARNING_ENV_VAR, False)
def set_disable_warnings(disable=True):
if disable:
os.environ[WARNING_ENV_VAR] = str(disable)
if disable: os.environ[WARNING_ENV_VAR] = str(disable)
def set_debug_mode(enabled, level=1):
if enabled:
os.environ[DEV_DEBUG_VAR] = str(level)
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
os.environ.update(
{
DEBUG_ENV_VAR: str(enabled),
@@ -126,132 +78,79 @@ def set_debug_mode(enabled, level=1):
}
)
set_disable_warnings(not enabled)
def set_quiet_mode(enabled):
os.environ.update(
{QUIET_ENV_VAR: str(enabled), DEBUG_ENV_VAR: str(not enabled), _GRPC_DEBUG_ENV_VAR: 'NONE', 'CT2_VERBOSE': '-1'}
)
set_disable_warnings(enabled)
def gen_random_uuid(prefix: str | None = None) -> str:
return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
def gen_random_uuid(prefix: str | None = None) -> str: return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
# NOTE: `compose` any number of unary functions into a single unary function
# compose(f, g, h)(x) == f(g(h(x))); compose(f, g, h)(x, y, z) == f(g(h(x, y, z)))
def compose(*funcs):
return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)
def compose(*funcs): return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs)
# NOTE: `apply` a transform function that is invoked on results returned from the decorated function
# apply(reversed)(func)(*args, **kwargs) == reversed(func(*args, **kwargs))
def apply(transform):
return lambda func: functools.wraps(func)(compose(transform, func))
def validate_is_path(maybe_path):
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def first_not_none(*args, default=None):
return next((arg for arg in args if arg is not None), default)
def apply(transform): return lambda func: functools.wraps(func)(compose(transform, func))
def validate_is_path(maybe_path): return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def first_not_none(*args, default=None): return next((arg for arg in args if arg is not None), default)
def generate_context(framework_name):
from bentoml._internal.models.model import ModelContext
framework_versions = {
'transformers': pkg.get_pkg_version('transformers'),
'safetensors': pkg.get_pkg_version('safetensors'),
'optimum': pkg.get_pkg_version('optimum'),
'accelerate': pkg.get_pkg_version('accelerate'),
}
if iutils.is_torch_available():
framework_versions['torch'] = pkg.get_pkg_version('torch')
if iutils.is_ctranslate_available():
framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2')
if iutils.is_vllm_available():
framework_versions['vllm'] = pkg.get_pkg_version('vllm')
if iutils.is_autoawq_available():
framework_versions['autoawq'] = pkg.get_pkg_version('autoawq')
if iutils.is_autogptq_available():
framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq')
if iutils.is_bentoml_available():
framework_versions['bentoml'] = pkg.get_pkg_version('bentoml')
if iutils.is_triton_available():
framework_versions['triton'] = pkg.get_pkg_version('triton')
if iutils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch')
if iutils.is_ctranslate_available(): framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2')
if iutils.is_vllm_available(): framework_versions['vllm'] = pkg.get_pkg_version('vllm')
if iutils.is_autoawq_available(): framework_versions['autoawq'] = pkg.get_pkg_version('autoawq')
if iutils.is_autogptq_available(): framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq')
if iutils.is_bentoml_available(): framework_versions['bentoml'] = pkg.get_pkg_version('bentoml')
if iutils.is_triton_available(): framework_versions['triton'] = pkg.get_pkg_version('triton')
if iutils.is_flash_attn_2_available(): framework_versions['flash_attn'] = pkg.get_pkg_version('flash_attn')
return ModelContext(framework_name=framework_name, framework_versions=framework_versions)
@functools.lru_cache(maxsize=1)
def in_notebook():
try:
from IPython.core.getipython import get_ipython
return 'IPKernelApp' in get_ipython().config
from IPython.core.getipython import get_ipython; return 'IPKernelApp' in get_ipython().config
except Exception:
return False
_TOKENIZER_PREFIX = '_tokenizer_'
def flatten_attrs(**attrs):
_TOKENIZER_PREFIX = '_tokenizer_'
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
for k in tuple(attrs.keys()):
if k.startswith(_TOKENIZER_PREFIX):
del attrs[k]
return attrs, tokenizer_attrs
# Special debug flag controled via DEBUG
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False))
# Whether to show the codenge for debug purposes
SHOW_CODEGEN = DEBUG and (
os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
)
SHOW_CODEGEN = DEBUG and os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
MYPY = False
class ExceptionFilter(logging.Filter):
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
if exclude_exceptions is None:
exclude_exceptions = []
if exclude_exceptions is None: exclude_exceptions = []
try:
from circus.exc import ConflictError
if ConflictError not in exclude_exceptions:
exclude_exceptions.append(ConflictError)
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
except ImportError:
pass
super(ExceptionFilter, self).__init__(**kwargs)
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
etype, _, _ = record.exc_info
if etype is not None:
for exc in self.EXCLUDE_EXCEPTIONS:
if issubclass(etype, exc):
return False
if issubclass(etype, exc): return False
return True
class InfoFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return logging.INFO <= record.levelno < logging.WARNING
def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING
class WarningFilter(logging.Filter): # FIXME: Why does this not work?
def filter(self, record: logging.LogRecord) -> bool:
if get_disable_warnings():
return record.levelno >= logging.ERROR
if get_disable_warnings(): return record.levelno >= logging.ERROR
return True
_LOGGING_CONFIG = {
'version': 1,
'disable_existing_loggers': True,
@@ -274,8 +173,6 @@ _LOGGING_CONFIG = {
},
'root': {'level': logging.WARNING},
}
def configure_logging():
if get_quiet_mode():
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
@@ -295,8 +192,6 @@ def configure_logging():
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
logging.config.dictConfig(_LOGGING_CONFIG)
# XXX: define all classes, functions import above this line
# since _extras will be the locals() import from this file.
_extras = {
@@ -332,6 +227,7 @@ __lazy = LazyModule(
'is_ctranslate_available',
'is_transformers_available',
'is_autoawq_available',
'is_flash_attn_2_available',
'is_bentoml_available',
'is_triton_available',
],

View File

@@ -16,6 +16,7 @@ from openllm_core.utils.import_utils import (
is_jupytext_available as is_jupytext_available,
is_notebook_available as is_notebook_available,
is_peft_available as is_peft_available,
is_flash_attn_2_available as is_flash_attn_2_available,
is_torch_available as is_torch_available,
is_transformers_available as is_transformers_available,
is_vllm_available as is_vllm_available,

View File

@@ -15,8 +15,6 @@ OPTIONAL_DEPENDENCIES = {
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
USE_VLLM = os.getenv('USE_VLLM', 'AUTO').upper()
def _is_package_available(package: str) -> bool:
_package_available = importlib.util.find_spec(package) is not None
if _package_available:
@@ -25,8 +23,6 @@ def _is_package_available(package: str) -> bool:
except importlib.metadata.PackageNotFoundError:
_package_available = False
return _package_available
_ctranslate_available = importlib.util.find_spec('ctranslate2') is not None
_vllm_available = importlib.util.find_spec('vllm') is not None
_grpc_available = importlib.util.find_spec('grpc') is not None
@@ -37,60 +33,24 @@ _transformers_available = _is_package_available('transformers')
_bentoml_available = _is_package_available('bentoml')
_peft_available = _is_package_available('peft')
_bitsandbytes_available = _is_package_available('bitsandbytes')
_flash_attn_available = _is_package_available('flash_attn')
_jupyter_available = _is_package_available('jupyter')
_jupytext_available = _is_package_available('jupytext')
_notebook_available = _is_package_available('notebook')
_autogptq_available = _is_package_available('auto_gptq')
def is_triton_available() -> bool:
return _triton_available
def is_ctranslate_available() -> bool:
return _ctranslate_available
def is_bentoml_available() -> bool:
return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
def is_transformers_available() -> bool:
return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
def is_grpc_available() -> bool:
return _grpc_available
def is_jupyter_available() -> bool:
return _jupyter_available
def is_jupytext_available() -> bool:
return _jupytext_available
def is_notebook_available() -> bool:
return _notebook_available
def is_peft_available() -> bool:
return _peft_available
def is_bitsandbytes_available() -> bool:
return _bitsandbytes_available
def is_autogptq_available() -> bool:
return _autogptq_available
def is_torch_available() -> bool:
return _torch_available
def is_triton_available() -> bool: return _triton_available
def is_ctranslate_available() -> bool: return _ctranslate_available
def is_bentoml_available() -> bool: return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml
def is_transformers_available() -> bool: return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers
def is_grpc_available() -> bool: return _grpc_available
def is_jupyter_available() -> bool: return _jupyter_available
def is_jupytext_available() -> bool: return _jupytext_available
def is_notebook_available() -> bool: return _notebook_available
def is_peft_available() -> bool: return _peft_available
def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
def is_autogptq_available() -> bool: return _autogptq_available
def is_torch_available() -> bool: return _torch_available
def is_flash_attn_2_available() -> bool: return _flash_attn_available
def is_autoawq_available() -> bool:
global _autoawq_available
try:
@@ -98,8 +58,6 @@ def is_autoawq_available() -> bool:
except importlib.metadata.PackageNotFoundError:
_autoawq_available = False
return _autoawq_available
def is_vllm_available() -> bool:
global _vllm_available
if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES or _vllm_available:

View File

@@ -83,8 +83,7 @@ class CTranslateRunnable(bentoml.Runnable):
SUPPORTS_CPU_MULTI_THREADING = True
def __init__(self, llm):
if not is_ctranslate_available():
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
@bentoml.Runnable.method(batchable=False)
@@ -127,8 +126,7 @@ class vLLMRunnable(bentoml.Runnable):
SUPPORTS_CPU_MULTI_THREADING = True
def __init__(self, llm):
if not is_vllm_available():
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
import vllm
self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
@@ -180,13 +178,8 @@ class PyTorchRunnable(bentoml.Runnable):
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
self.model.set_adapter(adapter_name)
async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
yield generation_output.model_dump_json()
async def forward(self, prompt_token_ids, request_id, stop, **attrs):
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
from ._generation import get_context_length, prepare_logits_processor
if adapter_name is not None: self.model.set_adapter(adapter_name)
max_new_tokens = attrs.pop('max_new_tokens', 256)
context_length = attrs.pop('context_length', None)
@@ -224,6 +217,8 @@ class PyTorchRunnable(bentoml.Runnable):
finish_reason = None
prompt_logprobs = []
prompt_token_indices = []
stopped = False
sample_logprobs: SampleLogprobs = [None] # The first token has no logprobs
for i in range(config['max_new_tokens']):
if i == 0: # prefill
@@ -247,10 +242,9 @@ class PyTorchRunnable(bentoml.Runnable):
)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
@@ -272,11 +266,11 @@ class PyTorchRunnable(bentoml.Runnable):
token = tokens[0]
output_token_ids.append(token)
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
sample_logprobs: SampleLogprobs = []
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs
if config['logprobs']:
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs
if config['prompt_logprobs']:
for token_id in prompt_token_ids:
@@ -296,40 +290,32 @@ class PyTorchRunnable(bentoml.Runnable):
clean_up_tokenization_spaces=True,
)
partially_stopped = False
if len(stop) > 0:
for it in stop:
pos = text.rfind(it, rfind_start)
if pos != -1:
text, stopped = text[:pos], True
break
else:
partially_stopped = is_partial_stop(text, it)
if partially_stopped:
break
if config['logprobs']:
sample_logprobs.append({token: token_logprobs})
if config['logprobs']: sample_logprobs.append({token: token_logprobs})
if not partially_stopped:
# TODO: calculate prompt_logprobs
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0,
text=text,
token_ids=tmp_output_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0,
text=text,
token_ids=tmp_output_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
request_id=request_id,
).model_dump_json()
if stopped:
break
else:
@@ -343,16 +329,16 @@ class PyTorchRunnable(bentoml.Runnable):
CompletionChunk(
index=0,
text=text,
token_ids=output_token_ids[input_len:],
token_ids=output_token_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=finish_reason,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
request_id=request_id,
)
).model_dump_json()
# Clean
del past_key_values, out

View File

@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm
from huggingface_hub import snapshot_download
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_autogptq_available
from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
from .weights import HfIgnore
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
if llm.config['model_type'] != 'causal_lm':
raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
use_flash_attention_2=is_flash_attn_2_available(),
**attrs,
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
check_unintialised_params(model)
return model

View File

@@ -35,6 +35,7 @@ from openllm_core.utils import (
is_bentoml_available as is_bentoml_available,
is_bitsandbytes_available as is_bitsandbytes_available,
is_ctranslate_available as is_ctranslate_available,
is_flash_attn_2_available as is_flash_attn_2_available,
is_grpc_available as is_grpc_available,
is_jupyter_available as is_jupyter_available,
is_jupytext_available as is_jupytext_available,
@@ -55,7 +56,7 @@ from openllm_core.utils import (
)
from openllm_core.utils.serde import converter as converter
from .._llm import LLM
from ._llm import LLM
def available_devices() -> Tuple[str, ...]: ...
def device_count() -> int: ...

View File

@@ -1237,7 +1237,7 @@ def prune_command(
for b in bentoml.bentos.list()
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
]
if model_name is None:
else:
available += [
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
]
@@ -1342,11 +1342,11 @@ def query_command(
if stream:
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **_memoized)
termui.echo(prompt, fg=input_fg, nl=False if stream else True)
termui.echo(prompt, fg=input_fg, nl=False)
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
else:
termui.echo(prompt, fg=input_fg, nl=False)
termui.echo(prompt, fg=input_fg, nl=True)
termui.echo(client.generate(prompt, **_memoized).outputs[0].text, fg=generated_fg, nl=False)
ctx.exit(0)