mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 10:29:36 -04:00
feat(llm): update warning envvar and add embedded mode (#618)
* chore: unify warning envvar and update type inference Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update documentation about embedded Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -87,9 +87,7 @@ def Runner(
|
||||
)
|
||||
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'))
|
||||
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, **attrs)
|
||||
if init_local:
|
||||
llm.runner.init_local(quiet=True)
|
||||
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, embedded=init_local, **attrs)
|
||||
return llm.runner
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from openllm_core._typing_compat import TupleAny
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import LazyLoader
|
||||
from openllm_core.utils import ReprMixin
|
||||
from openllm_core.utils import apply
|
||||
from openllm_core.utils import check_bool_env
|
||||
@@ -43,6 +42,9 @@ from openllm_core.utils import converter
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import flatten_attrs
|
||||
from openllm_core.utils import generate_hash_from_file
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_disable_warnings
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
from openllm_core.utils import is_peft_available
|
||||
from openllm_core.utils import resolve_filepath
|
||||
from openllm_core.utils import validate_is_path
|
||||
@@ -55,9 +57,13 @@ from .serialisation.constants import PEFT_CONFIG_NAME
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import transformers
|
||||
|
||||
from peft.config import PeftConfig
|
||||
from peft.peft_model import PeftModel
|
||||
from peft.peft_model import PeftModelForCausalLM
|
||||
from peft.peft_model import PeftModelForSeq2SeqLM
|
||||
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from bentoml._internal.runner.runner_handle import RunnerHandle
|
||||
@@ -65,10 +71,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core.utils.representation import ReprArgs
|
||||
|
||||
else:
|
||||
peft = LazyLoader('peft', globals(), 'peft')
|
||||
|
||||
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['peft.PeftConfig', str]]]
|
||||
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]]
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@@ -159,6 +162,7 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
trust_remote_code: bool = False,
|
||||
embedded: bool = False,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
# low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
|
||||
@@ -215,6 +219,14 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
# resolve the tag
|
||||
self._tag = model.tag
|
||||
|
||||
if embedded and not get_disable_warnings() and not get_quiet_mode():
|
||||
logger.warning(
|
||||
'You are using embedded mode, which means the models will be loaded into memory. This is often not recommended in production and should only be used for local development only.'
|
||||
)
|
||||
if not get_debug_mode():
|
||||
logger.info("To disable this warning, set 'OPENLLM_DISABLE_WARNING=True'")
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
@@ -401,9 +413,9 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
|
||||
def prepare_for_training(
|
||||
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
|
||||
) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
|
||||
from peft import get_peft_model
|
||||
from peft import prepare_model_for_kbit_training
|
||||
) -> tuple[PeftModel | PeftModelForCausalLM | PeftModelForSeq2SeqLM, T]:
|
||||
from peft.mapping import get_peft_model
|
||||
from peft.utils.other import prepare_model_for_kbit_training
|
||||
|
||||
peft_config = (
|
||||
self.config['fine_tune_strategies']
|
||||
|
||||
@@ -18,8 +18,10 @@ import openllm_core
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import WARNING_ENV_VAR
|
||||
from openllm_core.utils import codegen
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_disable_warnings
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
@@ -197,11 +199,8 @@ def _build(
|
||||
model_id,
|
||||
'--machine',
|
||||
'--serialisation',
|
||||
t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
]
|
||||
if quantize:
|
||||
@@ -237,7 +236,11 @@ def _build(
|
||||
args.extend(['--container-registry', container_registry, '--container-version-strategy', container_version_strategy])
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if force_push:
|
||||
args.append('--force-push')
|
||||
|
||||
current_disable_warning = get_disable_warnings()
|
||||
os.environ[WARNING_ENV_VAR] = str(True)
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
@@ -250,6 +253,7 @@ def _build(
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
|
||||
@@ -77,6 +77,7 @@ from openllm_core.utils import compose
|
||||
from openllm_core.utils import configure_logging
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_disable_warnings
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
from openllm_core.utils import is_torch_available
|
||||
from openllm_core.utils import resolve_user_filepath
|
||||
@@ -141,19 +142,22 @@ _object_setattr = object.__setattr__
|
||||
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension'))
|
||||
|
||||
|
||||
def backend_warning(backend: LiteralBackend):
|
||||
if backend == 'pt' and check_bool_env('OPENLLM_BACKEND_WARNING') and not get_quiet_mode():
|
||||
def backend_warning(backend: LiteralBackend, build: bool = False) -> None:
|
||||
if backend == 'pt' and (not get_disable_warnings()) and not get_quiet_mode():
|
||||
if openllm.utils.is_vllm_available():
|
||||
termui.warning(
|
||||
'\nvLLM is available, but using PyTorch backend instead. Note that vLLM is a lot more performant and should always be used in production (by explicitly set --backend vllm).'
|
||||
'vLLM is available, but using PyTorch backend instead. Note that vLLM is a lot more performant and should always be used in production (by explicitly set --backend vllm).'
|
||||
)
|
||||
else:
|
||||
termui.warning(
|
||||
'\nvLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
|
||||
'vLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
|
||||
)
|
||||
termui.debug(
|
||||
content="\nTip: if you are running 'openllm build' you can set '--backend vllm' to package your Bento with vLLM backend. To hide these messages, set 'OPENLLM_BACKEND_WARNING=False'\n"
|
||||
)
|
||||
if build:
|
||||
termui.info(
|
||||
"Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally."
|
||||
)
|
||||
if not get_debug_mode():
|
||||
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
|
||||
|
||||
|
||||
class Extensions(click.MultiCommand):
|
||||
@@ -425,13 +429,14 @@ def start_command(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
)
|
||||
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
termui.warning(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
|
||||
)
|
||||
if serialisation == 'safetensors' and quantize is not None and not get_disable_warnings() and not get_quiet_mode():
|
||||
termui.warning(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format.")
|
||||
termui.warning(
|
||||
f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure."
|
||||
)
|
||||
termui.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
|
||||
if not get_debug_mode():
|
||||
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
@@ -542,19 +547,17 @@ def start_grpc_command(
|
||||
|
||||
from ..serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
serialisation = t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
serialisation = first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
)
|
||||
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
|
||||
termui.warning(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
|
||||
)
|
||||
if serialisation == 'safetensors' and quantize is not None and not get_disable_warnings() and not get_quiet_mode():
|
||||
termui.warning(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format.")
|
||||
termui.warning(
|
||||
f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure."
|
||||
)
|
||||
termui.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
|
||||
if not get_debug_mode():
|
||||
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
@@ -824,9 +827,26 @@ def import_command(
|
||||
return response
|
||||
|
||||
|
||||
class DeploymentInstruction(t.TypedDict):
|
||||
@attr.define(auto_attribs=True)
class _Content:
  """A deployment hint made of a message template and the command it refers to.

  Converting an instance to ``str`` renders ``instr`` through ``str.format``
  with ``cmd`` supplied as the ``cmd`` keyword, so ``instr`` is expected to
  carry a ``{cmd}`` placeholder.
  """

  # Message template handed to ``str.format`` in ``__str__``.
  instr: str
  # Command text substituted for ``{cmd}``; also readable on its own.
  cmd: str

  def __str__(self) -> str:
    """Return ``instr`` with ``cmd`` substituted in."""
    template, command = self.instr, self.cmd
    return template.format(cmd=command)
|
||||
|
||||
|
||||
@attr.define(auto_attribs=True)
class DeploymentInstruction:
  """One suggested follow-up step after a build, tagged by deployment target.

  Instances also support ``obj['type']`` / ``obj['content']`` lookups, which
  are forwarded to attribute access.
  """

  # Deployment target this instruction applies to.
  type: t.Literal['container', 'bentocloud']
  # Message template plus command; ``str(content)`` yields the rendered text.
  content: _Content

  @classmethod
  def from_content(cls, type: t.Literal['container', 'bentocloud'], instr: str, cmd: str) -> DeploymentInstruction:
    """Build an instruction from its target, message template and command.

    Args:
      type: Deployment target, ``'container'`` or ``'bentocloud'``.
      instr: Message template; wrapped together with ``cmd`` in a ``_Content``.
      cmd: Command text stored alongside the template.

    Returns:
      A new ``DeploymentInstruction`` whose ``content`` is ``_Content(instr=instr, cmd=cmd)``.
    """
    return cls(type=type, content=_Content(instr=instr, cmd=cmd))

  def __getitem__(self, key: str) -> t.Any:
    """Return the attribute named ``key``.

    The result is the ``type`` literal or the ``_Content`` object depending on
    ``key``, hence the ``t.Any`` annotation rather than ``str``.

    Raises:
      AttributeError: If no attribute called ``key`` exists (propagated from ``getattr``).
    """
    return getattr(self, key)
|
||||
|
||||
|
||||
class BuildBentoOutput(t.TypedDict):
|
||||
@@ -985,7 +1005,7 @@ def build_command(
|
||||
),
|
||||
),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
backend_warning(llm.__llm_backend__, build=True)
|
||||
|
||||
os.environ.update(
|
||||
{
|
||||
@@ -1069,21 +1089,36 @@ def build_command(
|
||||
traceback.print_exc()
|
||||
raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err
|
||||
|
||||
def get_current_bentocloud_context() -> str:
  """Return the name of the BentoCloud context to use.

  Prefers ``ctx.obj.cloud_context`` when it is truthy. Otherwise runs
  ``bentoml cloud current-context`` and returns the ``name`` field of the JSON
  it prints.

  NOTE(review): ``ctx`` is not defined here; presumably it is the click context
  of the enclosing command -- confirm against the surrounding function.

  Raises:
    subprocess.CalledProcessError: If the ``bentoml`` command exits with a non-zero status.
  """
  explicit = t.cast(t.Optional[str], ctx.obj.cloud_context)
  if explicit:
    return explicit

  # Nothing was passed explicitly: ask the bentoml CLI which context is active.
  raw_output = subprocess.check_output(['bentoml', 'cloud', 'current-context'], env=os.environ)
  return t.cast(str, orjson.loads(raw_output)['name'])
|
||||
|
||||
response = BuildBentoOutput(
|
||||
state=state,
|
||||
tag=str(bento_tag),
|
||||
backend=llm.__llm_backend__,
|
||||
instructions=[
|
||||
DeploymentInstruction(
|
||||
type='bentocloud', content=f"Push to BentoCloud with 'bentoml push': `bentoml push {bento_tag}`"
|
||||
DeploymentInstruction.from_content(
|
||||
type='bentocloud',
|
||||
instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}",
|
||||
cmd=f'bentoml push {bento_tag} --context {get_current_bentocloud_context()}',
|
||||
),
|
||||
DeploymentInstruction(
|
||||
DeploymentInstruction.from_content(
|
||||
type='container',
|
||||
content=f"Container BentoLLM with 'bentoml containerize': `bentoml containerize {bento_tag} --opt progress=plain`",
|
||||
instr="🐳 Container BentoLLM with 'bentoml containerize':\n $ {cmd}",
|
||||
cmd=f'bentoml containerize {bento_tag} --opt progress=plain',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
plain_instruction = {i.type: i['content'].cmd for i in response['instructions']}
|
||||
if machine or get_debug_mode():
|
||||
response['instructions'] = plain_instruction
|
||||
if machine:
|
||||
termui.echo(f'__object__:{orjson.dumps(response).decode()}\n\n', fg='white')
|
||||
elif not get_quiet_mode() and (not push or not containerize):
|
||||
@@ -1093,9 +1128,9 @@ def build_command(
|
||||
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
|
||||
if not get_debug_mode():
|
||||
termui.echo(OPENLLM_FIGLET)
|
||||
termui.echo('\n📖 Next steps:\n\n', nl=False)
|
||||
termui.echo('📖 Next steps:\n', nl=False)
|
||||
for instruction in response['instructions']:
|
||||
termui.echo(f"* {instruction['content']}\n", nl=False)
|
||||
termui.echo(f" * {instruction['content']}\n", nl=False)
|
||||
|
||||
if push:
|
||||
BentoMLContainer.bentocloud_client.get().push_bento(
|
||||
@@ -1112,7 +1147,6 @@ def build_command(
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
|
||||
|
||||
response.pop('instructions')
|
||||
if get_debug_mode():
|
||||
termui.echo('\n' + orjson.dumps(response).decode(), fg=None)
|
||||
return response
|
||||
|
||||
@@ -11,7 +11,6 @@ import orjson
|
||||
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
|
||||
|
||||
logger = logging.getLogger('openllm')
|
||||
@@ -53,7 +52,10 @@ class JsonLog(t.TypedDict):
|
||||
|
||||
|
||||
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
|
||||
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
|
||||
if get_debug_mode():
|
||||
echo(content, fg=fg)
|
||||
else:
|
||||
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
|
||||
|
||||
|
||||
warning = functools.partial(log, level=Level.WARNING)
|
||||
@@ -64,7 +66,7 @@ info = functools.partial(log, level=Level.INFO)
|
||||
notset = functools.partial(log, level=Level.NOTSET)
|
||||
|
||||
|
||||
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
|
||||
def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
|
||||
if json:
|
||||
text = orjson.loads(text)
|
||||
if 'content' in text and 'log_level' in text:
|
||||
@@ -77,8 +79,7 @@ def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: boo
|
||||
content = t.cast(str, text)
|
||||
attrs['fg'] = fg
|
||||
|
||||
if not get_quiet_mode():
|
||||
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
(click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
|
||||
@@ -40,6 +40,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core.utils import generate_context as generate_context
|
||||
from openllm_core.utils import generate_hash_from_file as generate_hash_from_file
|
||||
from openllm_core.utils import get_debug_mode as get_debug_mode
|
||||
from openllm_core.utils import get_disable_warnings as get_disable_warnings
|
||||
from openllm_core.utils import get_quiet_mode as get_quiet_mode
|
||||
from openllm_core.utils import in_notebook as in_notebook
|
||||
from openllm_core.utils import is_autoawq_available as is_autoawq_available
|
||||
@@ -61,6 +62,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core.utils import resolve_user_filepath as resolve_user_filepath
|
||||
from openllm_core.utils import serde as serde
|
||||
from openllm_core.utils import set_debug_mode as set_debug_mode
|
||||
from openllm_core.utils import set_disable_warnings as set_disable_warnings
|
||||
from openllm_core.utils import set_quiet_mode as set_quiet_mode
|
||||
from openllm_core.utils import validate_is_path as validate_is_path
|
||||
from openllm_core.utils.serde import converter as converter
|
||||
|
||||
Reference in New Issue
Block a user