feat(llm): update warning envvar and add embedded mode (#618)

* chore: unify warning envvar and update type inference

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update documentation about embedded mode

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-12 17:39:06 -05:00
committed by GitHub
parent 7e1fb35a71
commit c3416c0afd
9 changed files with 140 additions and 54 deletions

View File

@@ -87,9 +87,7 @@ def Runner(
)
backend = t.cast(LiteralBackend, first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'))
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, **attrs)
if init_local:
llm.runner.init_local(quiet=True)
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, embedded=init_local, **attrs)
return llm.runner

View File

@@ -34,7 +34,6 @@ from openllm_core._typing_compat import TupleAny
from openllm_core.exceptions import MissingDependencyError
from openllm_core.prompts import PromptTemplate
from openllm_core.utils import DEBUG
from openllm_core.utils import LazyLoader
from openllm_core.utils import ReprMixin
from openllm_core.utils import apply
from openllm_core.utils import check_bool_env
@@ -43,6 +42,9 @@ from openllm_core.utils import converter
from openllm_core.utils import first_not_none
from openllm_core.utils import flatten_attrs
from openllm_core.utils import generate_hash_from_file
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_disable_warnings
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_peft_available
from openllm_core.utils import resolve_filepath
from openllm_core.utils import validate_is_path
@@ -55,9 +57,13 @@ from .serialisation.constants import PEFT_CONFIG_NAME
if t.TYPE_CHECKING:
import peft
import transformers
from peft.config import PeftConfig
from peft.peft_model import PeftModel
from peft.peft_model import PeftModelForCausalLM
from peft.peft_model import PeftModelForSeq2SeqLM
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.runner_handle import RunnerHandle
@@ -65,10 +71,7 @@ if t.TYPE_CHECKING:
from openllm_core._configuration import LLMConfig
from openllm_core.utils.representation import ReprArgs
else:
peft = LazyLoader('peft', globals(), 'peft')
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['peft.PeftConfig', str]]]
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]]
P = ParamSpec('P')
@@ -159,6 +162,7 @@ class LLM(t.Generic[M, T], ReprMixin):
adapter_map: dict[str, str] | None = None,
serialisation: LiteralSerialisation = 'safetensors',
trust_remote_code: bool = False,
embedded: bool = False,
**attrs: t.Any,
):
# low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
@@ -215,6 +219,14 @@ class LLM(t.Generic[M, T], ReprMixin):
# resolve the tag
self._tag = model.tag
if embedded and not get_disable_warnings() and not get_quiet_mode():
logger.warning(
'You are using embedded mode, which means the models will be loaded into memory. This is often not recommended in production and should only be used for local development only.'
)
if not get_debug_mode():
logger.info("To disable this warning, set 'OPENLLM_DISABLE_WARNING=True'")
self.runner.init_local(quiet=True)
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
model_id, *maybe_revision = model_id.rsplit(':')
@@ -401,9 +413,9 @@ class LLM(t.Generic[M, T], ReprMixin):
def prepare_for_training(
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
from peft import get_peft_model
from peft import prepare_model_for_kbit_training
) -> tuple[PeftModel | PeftModelForCausalLM | PeftModelForSeq2SeqLM, T]:
from peft.mapping import get_peft_model
from peft.utils.other import prepare_model_for_kbit_training
peft_config = (
self.config['fine_tune_strategies']

View File

@@ -18,8 +18,10 @@ import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import WARNING_ENV_VAR
from openllm_core.utils import codegen
from openllm_core.utils import first_not_none
from openllm_core.utils import get_disable_warnings
from openllm_core.utils import is_vllm_available
@@ -197,11 +199,8 @@ def _build(
model_id,
'--machine',
'--serialisation',
t.cast(
LiteralSerialisation,
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
]
if quantize:
@@ -237,7 +236,11 @@ def _build(
args.extend(['--container-registry', container_registry, '--container-version-strategy', container_version_strategy])
if additional_args:
args.extend(additional_args)
if force_push:
args.append('--force-push')
current_disable_warning = get_disable_warnings()
os.environ[WARNING_ENV_VAR] = str(True)
try:
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
except subprocess.CalledProcessError as e:
@@ -250,6 +253,7 @@ def _build(
raise ValueError(
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
)
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
try:
result = orjson.loads(matched.group(1))
except orjson.JSONDecodeError as e:

View File

@@ -77,6 +77,7 @@ from openllm_core.utils import compose
from openllm_core.utils import configure_logging
from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_disable_warnings
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_torch_available
from openllm_core.utils import resolve_user_filepath
@@ -141,19 +142,22 @@ _object_setattr = object.__setattr__
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension'))
def backend_warning(backend: LiteralBackend):
if backend == 'pt' and check_bool_env('OPENLLM_BACKEND_WARNING') and not get_quiet_mode():
def backend_warning(backend: LiteralBackend, build: bool = False) -> None:
if backend == 'pt' and (not get_disable_warnings()) and not get_quiet_mode():
if openllm.utils.is_vllm_available():
termui.warning(
'\nvLLM is available, but using PyTorch backend instead. Note that vLLM is a lot more performant and should always be used in production (by explicitly set --backend vllm).'
'vLLM is available, but using PyTorch backend instead. Note that vLLM is a lot more performant and should always be used in production (by explicitly set --backend vllm).'
)
else:
termui.warning(
'\nvLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
'vLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
)
termui.debug(
content="\nTip: if you are running 'openllm build' you can set '--backend vllm' to package your Bento with vLLM backend. To hide these messages, set 'OPENLLM_BACKEND_WARNING=False'\n"
)
if build:
termui.info(
"Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally."
)
if not get_debug_mode():
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
class Extensions(click.MultiCommand):
@@ -425,13 +429,14 @@ def start_command(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
)
if serialisation == 'safetensors' and quantize is not None and not get_disable_warnings() and not get_quiet_mode():
termui.warning(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format.")
termui.warning(
f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure."
)
termui.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
if not get_debug_mode():
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
@@ -542,19 +547,17 @@ def start_grpc_command(
from ..serialisation.transformers.weights import has_safetensors_weights
serialisation = t.cast(
LiteralSerialisation,
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
serialisation = first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
)
if serialisation == 'safetensors' and quantize is not None and not get_disable_warnings() and not get_quiet_mode():
termui.warning(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format.")
termui.warning(
f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure."
)
termui.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
if not get_debug_mode():
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
@@ -824,9 +827,26 @@ def import_command(
return response
class DeploymentInstruction(t.TypedDict):
@attr.define(auto_attribs=True)
class _Content:
    """A message template paired with the command it advertises.

    ``instr`` is a ``str.format`` template rendered with a ``cmd`` keyword
    argument; ``cmd`` is the command text substituted into it.
    """

    instr: str
    cmd: str

    def __str__(self) -> str:
        # Rendering happens only on conversion to ``str``; the raw ``cmd``
        # remains readable as a plain attribute.
        rendered = self.instr.format(cmd=self.cmd)
        return rendered
@attr.define(auto_attribs=True)
class DeploymentInstruction:
type: t.Literal['container', 'bentocloud']
content: str
content: _Content
@classmethod
def from_content(cls, type: t.Literal['container', 'bentocloud'], instr: str, cmd: str) -> DeploymentInstruction:
return cls(type=type, content=_Content(instr=instr, cmd=cmd))
def __getitem__(self, key: str) -> str:
return getattr(self, key)
class BuildBentoOutput(t.TypedDict):
@@ -985,7 +1005,7 @@ def build_command(
),
),
)
backend_warning(llm.__llm_backend__)
backend_warning(llm.__llm_backend__, build=True)
os.environ.update(
{
@@ -1069,21 +1089,36 @@ def build_command(
traceback.print_exc()
raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err
def get_current_bentocloud_context() -> str:
    """Return the BentoCloud context name to use.

    The value stored on ``ctx.obj.cloud_context`` wins when it is truthy.
    Otherwise ``bentoml cloud current-context`` is run with the current
    environment and the ``name`` field of its JSON output is returned.
    """
    explicit = t.cast(t.Optional[str], ctx.obj.cloud_context)
    if explicit:
        return explicit
    # Nothing was passed in: ask the bentoml CLI for the active context.
    raw = subprocess.check_output(['bentoml', 'cloud', 'current-context'], env=os.environ)
    return t.cast(str, orjson.loads(raw)['name'])
response = BuildBentoOutput(
state=state,
tag=str(bento_tag),
backend=llm.__llm_backend__,
instructions=[
DeploymentInstruction(
type='bentocloud', content=f"Push to BentoCloud with 'bentoml push': `bentoml push {bento_tag}`"
DeploymentInstruction.from_content(
type='bentocloud',
instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}",
cmd=f'bentoml push {bento_tag} --context {get_current_bentocloud_context()}',
),
DeploymentInstruction(
DeploymentInstruction.from_content(
type='container',
content=f"Container BentoLLM with 'bentoml containerize': `bentoml containerize {bento_tag} --opt progress=plain`",
instr="🐳 Container BentoLLM with 'bentoml containerize':\n $ {cmd}",
cmd=f'bentoml containerize {bento_tag} --opt progress=plain',
),
],
)
plain_instruction = {i.type: i['content'].cmd for i in response['instructions']}
if machine or get_debug_mode():
response['instructions'] = plain_instruction
if machine:
termui.echo(f'__object__:{orjson.dumps(response).decode()}\n\n', fg='white')
elif not get_quiet_mode() and (not push or not containerize):
@@ -1093,9 +1128,9 @@ def build_command(
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
if not get_debug_mode():
termui.echo(OPENLLM_FIGLET)
termui.echo('\n📖 Next steps:\n\n', nl=False)
termui.echo('📖 Next steps:\n', nl=False)
for instruction in response['instructions']:
termui.echo(f"* {instruction['content']}\n", nl=False)
termui.echo(f" * {instruction['content']}\n", nl=False)
if push:
BentoMLContainer.bentocloud_client.get().push_bento(
@@ -1112,7 +1147,6 @@ def build_command(
except Exception as err:
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
response.pop('instructions')
if get_debug_mode():
termui.echo('\n' + orjson.dumps(response).decode(), fg=None)
return response

View File

@@ -11,7 +11,6 @@ import orjson
from openllm_core._typing_compat import DictStrAny
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
logger = logging.getLogger('openllm')
@@ -53,7 +52,10 @@ class JsonLog(t.TypedDict):
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
if get_debug_mode():
echo(content, fg=fg)
else:
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
warning = functools.partial(log, level=Level.WARNING)
@@ -64,7 +66,7 @@ info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
if json:
text = orjson.loads(text)
if 'content' in text and 'log_level' in text:
@@ -77,8 +79,7 @@ def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: boo
content = t.cast(str, text)
attrs['fg'] = fg
if not get_quiet_mode():
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
(click.echo if not _with_style else click.secho)(content, **attrs)
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))

View File

@@ -40,6 +40,7 @@ if t.TYPE_CHECKING:
from openllm_core.utils import generate_context as generate_context
from openllm_core.utils import generate_hash_from_file as generate_hash_from_file
from openllm_core.utils import get_debug_mode as get_debug_mode
from openllm_core.utils import get_disable_warnings as get_disable_warnings
from openllm_core.utils import get_quiet_mode as get_quiet_mode
from openllm_core.utils import in_notebook as in_notebook
from openllm_core.utils import is_autoawq_available as is_autoawq_available
@@ -61,6 +62,7 @@ if t.TYPE_CHECKING:
from openllm_core.utils import resolve_user_filepath as resolve_user_filepath
from openllm_core.utils import serde as serde
from openllm_core.utils import set_debug_mode as set_debug_mode
from openllm_core.utils import set_disable_warnings as set_disable_warnings
from openllm_core.utils import set_quiet_mode as set_quiet_mode
from openllm_core.utils import validate_is_path as validate_is_path
from openllm_core.utils.serde import converter as converter