mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-07 22:33:28 -05:00
chore(releases): remove deadcode
Signed-off-by: Aaron Pham (mbp16) <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,366 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, logging, os, typing as t
|
||||
import bentoml, openllm, click, inflection, click_option_group as cog
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import shell_completion as sc
|
||||
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
AnyCallable,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
for it in bentoml.list()
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
'tracing.sample_rate=1.0',
|
||||
'api_server.max_runner_connections=25',
|
||||
f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend([
|
||||
'api_server.http.cors.enabled=true',
|
||||
'api_server.http.cors.access_control_allow_origins="*"',
|
||||
])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG:
|
||||
logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
|
||||
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value:
|
||||
return None
|
||||
if _adapter_mapping_key not in ctx.params:
|
||||
ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = name
|
||||
return None
|
||||
|
||||
|
||||
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
|
||||
shared = [
|
||||
dtype_option(factory=factory),
|
||||
revision_option(factory=factory), #
|
||||
backend_option(factory=factory),
|
||||
quantize_option(factory=factory), #
|
||||
serialisation_option(factory=factory),
|
||||
]
|
||||
if not _eager:
|
||||
return shared
|
||||
return compose(*shared)(fn)
|
||||
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
|
||||
|
||||
def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
group = cog.optgroup.group(
|
||||
'Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]'
|
||||
)
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands['serve']
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [
|
||||
p
|
||||
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
attrs.pop('param_type_name')
|
||||
# name is not a valid args
|
||||
attrs.pop('name')
|
||||
# type can be determine from default value
|
||||
attrs.pop('type')
|
||||
param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def start_decorator(fn: FC) -> FC:
|
||||
composed = compose(
|
||||
cog.optgroup.group(
|
||||
'LLM Options',
|
||||
help="""The following options are related to running LLM Server as well as optimization options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
The following are either in our roadmap or currently being worked on:
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
),
|
||||
*optimization_decorator(fn, factory=cog.optgroup, _eager=False),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
adapter_id_option(factory=cog.optgroup),
|
||||
)
|
||||
|
||||
return composed(fn)
|
||||
|
||||
|
||||
def parse_device_callback(
|
||||
_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
||||
) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
|
||||
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
|
||||
provide type-safe click.option or click.argument wrapper for all compatible factory.
|
||||
"""
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument':
|
||||
attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
callback = getattr(factory, factory_attr, None)
|
||||
if callback is None:
|
||||
raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
|
||||
def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--dtype',
|
||||
type=str,
|
||||
envvar='TORCH_DTYPE',
|
||||
default='auto',
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_MODEL_ID',
|
||||
show_envvar=True,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def revision_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--revision',
|
||||
'--model-version',
|
||||
'model_version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='Runtime to use for both serialisation/inference engine.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=str,
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
- ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
|
||||
|
||||
- ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
|
||||
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
|
||||
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
"""
|
||||
+ (
|
||||
"""
|
||||
> [!NOTE] that this will set the mode for serving within deployment."""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
|
||||
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
@@ -1,276 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t, bentoml, openllm_core, orjson
|
||||
from simple_di import Provide, inject
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralBackend, LiteralQuantise, LiteralString
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _start(
|
||||
model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]``
|
||||
|
||||
> [!NOTE] This will create a blocking process, so if you use this API, you can create a running sub thread
|
||||
> to start the server instead of blocking the main thread.
|
||||
|
||||
``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction.
|
||||
|
||||
Args:
|
||||
model_id: The model id to start this LLMServer
|
||||
timeout: The server timeout
|
||||
workers_per_resource: Number of workers per resource assigned.
|
||||
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
|
||||
for more information. By default, this is set to 1.
|
||||
|
||||
> [!NOTE] ``--workers-per-resource`` will also accept the following strategies:
|
||||
> - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
> - ``conserved``: This will determine the number of available GPU resources, and only assign
|
||||
> one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is
|
||||
> equivalent to ``--workers-per-resource 0.25``.
|
||||
device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all'
|
||||
argument to assign all available GPUs to this LLM.
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
cors: Whether to enable CORS for this LLM. By default, this is set to ``False``.
|
||||
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
|
||||
backend: The backend to use for this LLM. By default, this is set to ``pt``.
|
||||
additional_args: Additional arguments to pass to ``openllm start``.
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
args: list[str] = [model_id]
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend([
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
args.extend(['--quantize', str(quantize)])
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
args.append('--return-process')
|
||||
|
||||
cmd = start_command
|
||||
return cmd.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
@inject
|
||||
def _build(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
) -> bentoml.Bento:
|
||||
"""Package a LLM into a BentoLLM.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time.
|
||||
|
||||
``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI.
|
||||
|
||||
Args:
|
||||
model_id: The model id to build this BentoLLM
|
||||
model_version: Optional model version for this given LLM
|
||||
bento_version: Optional bento veresion for this given BentoLLM
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
|
||||
build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory.
|
||||
enable_features: Additional OpenLLM features to be included with this BentoLLM.
|
||||
dockerfile_template: The dockerfile template to use for building BentoLLM. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template.
|
||||
overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``.
|
||||
push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.
|
||||
containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.
|
||||
Note that 'containerize' and 'push' are mutually exclusive
|
||||
container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR.
|
||||
serialisation: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True`
|
||||
additional_args: Additional arguments to pass to ``openllm build``.
|
||||
bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store.
|
||||
|
||||
Returns:
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
args: list[str] = [
|
||||
sys.executable,
|
||||
'-m',
|
||||
'openllm',
|
||||
'build',
|
||||
model_id,
|
||||
'--machine',
|
||||
'--quiet',
|
||||
'--serialisation',
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
]
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
if containerize and push:
|
||||
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push:
|
||||
args.extend(['--push'])
|
||||
if containerize:
|
||||
args.extend(['--containerize'])
|
||||
if build_ctx:
|
||||
args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features:
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(['--model-version', model_version])
|
||||
if bento_version:
|
||||
args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template:
|
||||
args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if force_push:
|
||||
args.append('--force-push')
|
||||
|
||||
current_disable_warning = get_disable_warnings()
|
||||
os.environ[WARNING_ENV_VAR] = str(True)
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
) from e
|
||||
return bentoml.get(result['tag'], _bento_store=bento_store)
|
||||
|
||||
|
||||
def _import_model(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
> If ``quantize`` is passed, the model weights will be saved as quantized weights. You should
|
||||
> only use this option if you want the weight to be quantized by default. Note that OpenLLM also
|
||||
> support on-demand quantisation during initial startup.
|
||||
|
||||
``openllm.import_model`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``.
|
||||
|
||||
> [!NOTE]
|
||||
> ``openllm.start`` will automatically invoke ``openllm.import_model`` under the hood.
|
||||
|
||||
Args:
|
||||
model_id: required model id for this given LLM
|
||||
model_version: Optional model version for this given LLM
|
||||
backend: The backend to use for this LLM. By default, this is set to ``pt``.
|
||||
quantize: Quantize the model weights. This is only applicable for PyTorch models.
|
||||
Possible quantisation strategies:
|
||||
- int8: Quantize the model with 8bit (bitsandbytes required)
|
||||
- int4: Quantize the model with 4bit (bitsandbytes required)
|
||||
- gptq: Quantize the model with GPTQ (auto-gptq required)
|
||||
serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. Default behaviour is similar to ``safe_serialization=False``.
|
||||
additional_args: Additional arguments to pass to ``openllm import``.
|
||||
|
||||
Returns:
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
|
||||
args = [model_id, '--quiet']
|
||||
if backend is not None:
|
||||
args.extend(['--backend', backend])
|
||||
if model_version is not None:
|
||||
args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None:
|
||||
args.extend(['--quantize', quantize])
|
||||
if serialisation is not None:
|
||||
args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None:
|
||||
args.extend(additional_args)
|
||||
return import_command.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, build, import_model, list_models = (
|
||||
codegen.gen_sdk(_start),
|
||||
codegen.gen_sdk(_build),
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['build', 'import_model', 'list_models', 'start']
|
||||
5
openllm-python/src/openllm_cli/entrypoint.py
Normal file
5
openllm-python/src/openllm_cli/entrypoint.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# mainly for backward compatible entrypoint.
|
||||
from _openllm_tiny._entrypoint import cli as cli
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
@@ -1,16 +0,0 @@
|
||||
"""OpenLLM CLI Extension.
|
||||
|
||||
The following directory contains all possible extensions for OpenLLM CLI
|
||||
For adding new extension, just simply name that ext to `<name_ext>.py` and define
|
||||
a ``click.command()`` with the following format:
|
||||
|
||||
```python
|
||||
import click
|
||||
|
||||
@click.command(<name_ext>)
|
||||
...
|
||||
def cli(...): # <- this is important here, it should always name CLI in order for the extension resolver to know how to import this extensions.
|
||||
```
|
||||
|
||||
NOTE: Make sure to keep this file blank such that it won't mess with the import order.
|
||||
"""
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import psutil
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import bento_complete_envvar, machine_option
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
|
||||
ctx.fail(
|
||||
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
|
||||
)
|
||||
if machine:
|
||||
return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS:
|
||||
subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else:
|
||||
subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.bento.bento import BentoInfo
|
||||
from bentoml._internal.bento.build_config import DockerOptions
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.container.generate import generate_containerfile
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import converter
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command(
|
||||
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
|
||||
)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
# The logic below are similar to bentoml._internal.container.construct_containerfile
|
||||
with open(bentomodel.path_of('bento.yaml'), 'r') as f:
|
||||
options = BentoInfo.from_yaml_file(f)
|
||||
# NOTE: dockerfile_template is already included in the
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
# construct_containerfile. Hence it is safe to set it to None here.
|
||||
# See https://github.com/bentoml/BentoML/issues/3399.
|
||||
docker_attrs = converter.unstructure(options.docker)
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
|
||||
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(
|
||||
docker=DockerOptions(**docker_attrs),
|
||||
build_ctx=bentomodel.path,
|
||||
conda=options.conda,
|
||||
bento_fs=bentomodel._fs,
|
||||
enable_buildkit=True,
|
||||
add_header=True,
|
||||
)
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
@@ -1,179 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml_cli.utils import opt_callback
|
||||
|
||||
import openllm
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
_input_variables: t.Sequence[str] = attr.field(init=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._input_variables = default_formatter.extract_template_variables(self.template)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> PromptTemplate:
|
||||
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
|
||||
o = attr.evolve(self, template=self.template.format(**prompt_variables))
|
||||
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
|
||||
return o
|
||||
|
||||
def format(self, **attrs: t.Any) -> str:
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_id', shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--prompt-template-file', type=click.File(), default=None)
|
||||
@click.option('--chat-template-file', type=click.File(), default=None)
|
||||
@click.option('--system-message', type=str, default=None)
|
||||
@click.option(
|
||||
'--add-generation-prompt/--no-add-generation-prompt',
|
||||
default=False,
|
||||
help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This only applicable if model-id is a HF model_id',
|
||||
)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
prompt: str,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
chat_template_file: t.IO[t.Any] | None,
|
||||
system_message: str | None,
|
||||
add_generation_prompt: bool,
|
||||
_memoized: dict[str, t.Any],
|
||||
**_: t.Any,
|
||||
) -> str | None:
|
||||
"""Helpers for generating prompts.
|
||||
|
||||
\b
|
||||
It accepts remote HF model_ids as well as model name passed to `openllm start`.
|
||||
|
||||
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
|
||||
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
|
||||
```
|
||||
|
||||
If you need change the prompt template, you can create the template file that contains the jina2 template through `--chat-template-file`
|
||||
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
|
||||
```
|
||||
|
||||
\b
|
||||
|
||||
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
|
||||
Note that this is mainly for utilities, as OpenLLM won't use these prompts to format for you.
|
||||
|
||||
\b
|
||||
```bash
|
||||
openllm get-prompt mistral "Hello there"
|
||||
"""
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
|
||||
if prompt_template_file and chat_template_file:
|
||||
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
|
||||
|
||||
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
|
||||
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
|
||||
)
|
||||
if model_id in acceptable:
|
||||
logger.warning(
|
||||
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
|
||||
)
|
||||
config = openllm.AutoConfig.for_model(model_id)
|
||||
template = prompt_template_file.read() if prompt_template_file is not None else config.template
|
||||
system_message = system_message or config.system_message
|
||||
|
||||
try:
|
||||
formatted = (
|
||||
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
|
||||
)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
ctx.fail(str(err))
|
||||
else:
|
||||
import transformers
|
||||
|
||||
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
if chat_template_file is not None:
|
||||
chat_template_file = chat_template_file.read()
|
||||
if system_message is None:
|
||||
logger.warning('system-message is not provided, using default infer from the model architecture.\n')
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
system_message = (
|
||||
openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
)
|
||||
.model_construct_env()
|
||||
.system_message
|
||||
)
|
||||
break
|
||||
else:
|
||||
ctx.fail(
|
||||
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
|
||||
)
|
||||
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
|
||||
)
|
||||
|
||||
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
@@ -1,34 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm_cli import termui
|
||||
|
||||
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
mapping = {
|
||||
k: [
|
||||
{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [
|
||||
{'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))}
|
||||
for m in (bentoml.models.get(_.tag) for _ in b.info.models)
|
||||
],
|
||||
}
|
||||
for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'}))
|
||||
if b.info.labels['start_name'] == k
|
||||
]
|
||||
for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
@@ -1,49 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm_cli import termui
|
||||
from openllm_cli._factory import model_complete_envvar, model_name_argument
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""List available models in local store to be used with OpenLLM."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in bentoml.models.list()
|
||||
if 'framework' in i.info.labels
|
||||
and i.info.labels['framework'] == 'openllm'
|
||||
and 'model_name' in i.info.labels
|
||||
and i.info.labels['model_name'] == k
|
||||
]
|
||||
for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
]
|
||||
for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
|
||||
for k, val in ids_in_local_store.items()
|
||||
}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import importlib.machinery
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import jupytext
|
||||
import nbformat
|
||||
import yaml
|
||||
|
||||
from openllm_cli import playground, termui
|
||||
from openllm_core.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
if not all('description' in k for k in content.values()):
|
||||
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option(
|
||||
'--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
|
||||
A collections of notebooks to explore the capabilities of OpenLLM.
|
||||
This includes notebooks for fine-tuning, inference, and more.
|
||||
|
||||
All of the script available in the playground can also be run directly as a Python script:
|
||||
For example:
|
||||
|
||||
\b
|
||||
```bash
|
||||
python -m openllm.playground.falcon_tuned --help
|
||||
```
|
||||
|
||||
\b
|
||||
> [!NOTE]
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
_temp_dir = True
|
||||
output_dir = tempfile.mkdtemp(prefix='openllm-playground-')
|
||||
else:
|
||||
os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True)
|
||||
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug(
|
||||
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
|
||||
)
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
|
||||
continue
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
])
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
except KeyboardInterrupt:
|
||||
termui.echo('\nShutting down Jupyter server...', fg='yellow')
|
||||
if _temp_dir:
|
||||
termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
ctx.exit(0)
|
||||
@@ -1,14 +0,0 @@
|
||||
This folder represents a playground and source of truth for a features
|
||||
development around OpenLLM
|
||||
|
||||
```bash
|
||||
openllm playground
|
||||
```
|
||||
|
||||
## Usage for developing this module
|
||||
|
||||
Write a python script that can be used by `jupytext` to convert it to a jupyter
|
||||
notebook via `openllm playground`
|
||||
|
||||
Make sure to add the new file and its documentation into
|
||||
[`_meta.yml`](./_meta.yml).
|
||||
@@ -1,36 +0,0 @@
|
||||
features:
|
||||
description: |
|
||||
## General introduction to OpenLLM.
|
||||
|
||||
This script will demo a few features from OpenLLM:
|
||||
|
||||
- Usage of Auto class abstraction and run prediction with `generate`
|
||||
- Ability to send per-requests parameters
|
||||
- Runner integration with BentoML
|
||||
opt_tuned:
|
||||
description: |
|
||||
## Fine tuning OPT
|
||||
|
||||
This script demonstrate how one can easily fine tune OPT
|
||||
with [LoRa](https://arxiv.org/abs/2106.09685) and in int8 with bitsandbytes.
|
||||
|
||||
It is based on one of the Peft examples fine tuning script.
|
||||
It requires at least one GPU to be available, so make sure to have it.
|
||||
falcon_tuned:
|
||||
description: |
|
||||
## Fine tuning Falcon
|
||||
|
||||
This script demonstrate how one can fine tune Falcon using [QLoRa](https://arxiv.org/pdf/2305.14314.pdf),
|
||||
[trl](https://github.com/lvwerra/trl).
|
||||
|
||||
It is trained using OpenAssistant's Guanaco [dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco)
|
||||
|
||||
It requires at least one GPU to be available, so make sure to have it.
|
||||
llama2_qlora:
|
||||
description: |
|
||||
## Fine tuning LlaMA 2
|
||||
This script demonstrate how one can fine tune Falcon using LoRA with [trl](https://github.com/lvwerra/trl)
|
||||
|
||||
It is trained using the [Dolly datasets](https://huggingface.co/datasets/databricks/databricks-dolly-15k)
|
||||
|
||||
It requires at least one GPU to be available.
|
||||
@@ -1,95 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
DEFAULT_MODEL_ID = 'ybelkada/falcon-7b-sharded-bf16'
|
||||
DATASET_NAME = 'timdettmers/openassistant-guanaco'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default='paged_adamw_32bit')
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
logging_steps: int = dataclasses.field(default=10)
|
||||
learning_rate: float = dataclasses.field(default=2e-4)
|
||||
max_grad_norm: float = dataclasses.field(default=0.3)
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default='constant')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'falcon'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(
|
||||
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
model, tokenizer = llm.prepare(
|
||||
'lora',
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias='none',
|
||||
target_modules=['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
|
||||
)
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split='train')
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
@@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_NEW_TOKENS = 384
|
||||
|
||||
Q = 'Answer the following question, step by step:\n{q}\nA:'
|
||||
question = 'What is the meaning of life?'
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('question', default=question)
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
args = parser.parse_args(args=[question])
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any]('facebook/opt-2.7b')
|
||||
prompt = Q.format(q=args.question)
|
||||
|
||||
logger.info('-' * 50, "Running with 'generate()'", '-' * 50)
|
||||
res = await llm.generate(prompt)
|
||||
logger.info('=' * 10, 'Response:', res)
|
||||
|
||||
logger.info('-' * 50, "Running with 'generate()' with per-requests argument", '-' * 50)
|
||||
res = await llm.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
|
||||
logger.info('=' * 10, 'Response:', res)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _mp_fn(index: t.Any): # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
asyncio.run(main())
|
||||
@@ -1,236 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from random import randint, randrange
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split('.')
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if 'lm_head' in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove('lm_head')
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = 'meta-llama/Llama-2-7b-hf'
|
||||
# change this to 'main' if you want to use the latest llama
|
||||
DEFAULT_MODEL_VERSION = '335a02887eb6684d487240bbc28b5699298c3135'
|
||||
DATASET_NAME = 'databricks/databricks-dolly-15k'
|
||||
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample['context']) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = '\n\n'.join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample['text'] = f'{format_dolly(sample)}{tokenizer.eos_token}'
|
||||
return sample
|
||||
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
|
||||
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
global remainder
|
||||
# Concatenate all texts and add remainder from previous batch
|
||||
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
|
||||
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
|
||||
# get total number of tokens for batch
|
||||
batch_total_length = len(concatenated_examples[next(iter(sample.keys()))])
|
||||
|
||||
# get max number of chunks for batch
|
||||
if batch_total_length >= chunk_length:
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result['labels'] = result['input_ids'].copy()
|
||||
return result
|
||||
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split='train')
|
||||
|
||||
print(f'dataset size: {len(dataset)}')
|
||||
print(dataset[randrange(len(dataset))])
|
||||
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print('Sample from dolly-v2 ds:', dataset[randint(0, len(dataset))]['text'])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(
|
||||
lambda sample: tokenizer(sample['text']), batched=True, remove_columns=list(dataset.features)
|
||||
).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
|
||||
# Print total number of samples
|
||||
print(f'Total number of samples: {len(lm_dataset)}')
|
||||
return lm_dataset
|
||||
|
||||
|
||||
def prepare_for_int4_training(
|
||||
model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.LLM(
|
||||
model_id,
|
||||
revision=model_version,
|
||||
quantize='int4',
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map='auto',
|
||||
)
|
||||
print('Model summary:', llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f'Found {len(modules)} modules to quantize: {modules}')
|
||||
|
||||
model, tokenizer = llm.prepare('lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
if 'lm_head' in name or 'embed_tokens' in name:
|
||||
if hasattr(module, 'weight'):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
gradient_checkpointing: bool = dataclasses.field(default=True)
|
||||
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default='none')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'llama'))
|
||||
save_strategy: str = dataclasses.field(default='no')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
else:
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
|
||||
# import the model first hand
|
||||
openllm.import_model(model_id=model_args.model_id, model_version=model_args.model_version)
|
||||
|
||||
|
||||
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
import peft
|
||||
|
||||
transformers.set_seed(model_args.seed)
|
||||
|
||||
model, tokenizer = prepare_for_int4_training(
|
||||
model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16
|
||||
)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
if model_args.merge_weights:
|
||||
# note that this will requires larger GPU as we will load the whole model into memory
|
||||
|
||||
# merge adapter weights with base model and save
|
||||
# save int4 model
|
||||
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
|
||||
|
||||
# gc mem
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(
|
||||
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(
|
||||
os.path.join(os.getcwd(), 'outputs', 'merged_llama_lora'), safe_serialization=True, max_shard_size='2GB'
|
||||
)
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
|
||||
train_loop(model_args, training_args)
|
||||
@@ -1,81 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
# Make sure to have at least one GPU to run this script
|
||||
|
||||
openllm.utils.configure_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# On notebook, make sure to install the following
|
||||
# ! pip install -U openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from peft import PeftModel
|
||||
|
||||
DEFAULT_MODEL_ID = 'facebook/opt-6.7b'
|
||||
|
||||
|
||||
def load_trainer(
|
||||
model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments
|
||||
):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=dataset_dict['train'],
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=50)
|
||||
learning_rate: float = dataclasses.field(default=3e-4)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'opt'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize='int8')
|
||||
model, tokenizer = llm.prepare(
|
||||
'lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
|
||||
)
|
||||
|
||||
# ft on english_quotes
|
||||
data = load_dataset('Abirate/english_quotes')
|
||||
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)
|
||||
|
||||
trainer = load_trainer(model, tokenizer, data, training_args)
|
||||
model.config.use_cache = False # silence just for warning, reenable for inference later
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
@@ -1,90 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
from openllm_core._typing_compat import DictStrAny, TypedDict
|
||||
from openllm_core.utils import get_debug_mode
|
||||
|
||||
logger = logging.getLogger('openllm')
|
||||
from _openllm_tiny import _termui
|
||||
|
||||
|
||||
class Level(enum.IntEnum):
|
||||
NOTSET = logging.DEBUG
|
||||
DEBUG = logging.DEBUG
|
||||
INFO = logging.INFO
|
||||
WARNING = logging.WARNING
|
||||
ERROR = logging.ERROR
|
||||
CRITICAL = logging.CRITICAL
|
||||
|
||||
@property
|
||||
def color(self) -> str | None:
|
||||
return {
|
||||
Level.NOTSET: None,
|
||||
Level.DEBUG: 'cyan',
|
||||
Level.INFO: 'green',
|
||||
Level.WARNING: 'yellow',
|
||||
Level.ERROR: 'red',
|
||||
Level.CRITICAL: 'red',
|
||||
}[self]
|
||||
|
||||
@classmethod
|
||||
def from_logging_level(cls, level: int) -> Level:
|
||||
return {
|
||||
logging.DEBUG: Level.DEBUG,
|
||||
logging.INFO: Level.INFO,
|
||||
logging.WARNING: Level.WARNING,
|
||||
logging.ERROR: Level.ERROR,
|
||||
logging.CRITICAL: Level.CRITICAL,
|
||||
}[level]
|
||||
def __dir__():
|
||||
return dir(_termui)
|
||||
|
||||
|
||||
class JsonLog(TypedDict):
|
||||
log_level: Level
|
||||
content: str
|
||||
|
||||
|
||||
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
|
||||
if get_debug_mode():
|
||||
echo(content, fg=fg)
|
||||
else:
|
||||
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
|
||||
|
||||
|
||||
warning = functools.partial(log, level=Level.WARNING)
|
||||
error = functools.partial(log, level=Level.ERROR)
|
||||
critical = functools.partial(log, level=Level.CRITICAL)
|
||||
debug = functools.partial(log, level=Level.DEBUG)
|
||||
info = functools.partial(log, level=Level.INFO)
|
||||
notset = functools.partial(log, level=Level.NOTSET)
|
||||
|
||||
|
||||
def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
|
||||
if json:
|
||||
text = orjson.loads(text)
|
||||
if 'content' in text and 'log_level' in text:
|
||||
content = text['content']
|
||||
fg = Level.from_logging_level(text['log_level']).color
|
||||
else:
|
||||
content = orjson.dumps(text).decode()
|
||||
fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
|
||||
else:
|
||||
content = t.cast(str, text)
|
||||
attrs['fg'] = fg
|
||||
|
||||
(click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore,
|
||||
}
|
||||
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
|
||||
def __getattr__(name):
|
||||
return getattr(_termui, name)
|
||||
|
||||
Reference in New Issue
Block a user