mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
feat(engine): CTranslate2 (#698)
* chore: update instruction for dependencies Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * feat(experimental): CTranslate2 Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,6 @@ from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
get_literal_args,
|
||||
@@ -289,10 +288,10 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--dtype',
|
||||
type=click.Choice(['float16', 'float32', 'bfloat16', 'auto']),
|
||||
type=str,
|
||||
envvar='TORCH_DTYPE',
|
||||
default='auto',
|
||||
help='Optional dtype for casting tensors for running inference.',
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']. For CTranslate2, it also accepts the following ['int8_float32', 'int8_float16', 'int8_bfloat16']",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
@@ -341,15 +340,13 @@ def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
return cli_option(
|
||||
'--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
type=click.Choice(get_literal_args(LiteralBackend)),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
help='Runtime to use for both serialisation/inference engine.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
@@ -368,7 +365,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
type=str,
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
@@ -382,6 +379,10 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
|
||||
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
"""
|
||||
+ (
|
||||
|
||||
@@ -85,9 +85,7 @@ def _start(
|
||||
"""
|
||||
from .entrypoint import start_command, start_grpc_command
|
||||
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(
|
||||
backend, default='vllm' if is_vllm_available() else 'pt'
|
||||
)
|
||||
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
args: list[str] = [model_id]
|
||||
if system_message:
|
||||
|
||||
@@ -450,7 +450,7 @@ def start_command(
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
if backend == 'pt' and not torch.cuda.is_available():
|
||||
if dtype == 'auto':
|
||||
dtype = 'float'
|
||||
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
|
||||
@@ -465,7 +465,7 @@ def start_command(
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
serialisation=serialisation,
|
||||
torch_dtype=dtype,
|
||||
dtype=dtype,
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
|
||||
@@ -580,7 +580,7 @@ def start_grpc_command(
|
||||
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
if backend == 'pt' and not torch.cuda.is_available():
|
||||
if dtype == 'auto':
|
||||
dtype = 'float'
|
||||
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
|
||||
@@ -595,7 +595,7 @@ def start_grpc_command(
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
serialisation=serialisation,
|
||||
torch_dtype=dtype,
|
||||
dtype=dtype,
|
||||
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
@@ -661,14 +661,14 @@ def process_environ(
|
||||
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
|
||||
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
|
||||
'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
|
||||
'BACKEND': llm.__llm_backend__,
|
||||
'DTYPE': str(llm._torch_dtype).split('.')[-1],
|
||||
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
|
||||
}
|
||||
)
|
||||
if llm.quantise:
|
||||
environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
|
||||
environ['QUANTIZE'] = str(llm.quantise)
|
||||
if system_message:
|
||||
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template:
|
||||
@@ -695,10 +695,11 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
|
||||
|
||||
|
||||
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
|
||||
cmd_name = f'openllm build {model_id}'
|
||||
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
|
||||
if llm.quantise:
|
||||
cmd_name += f' --quantize {llm.quantise}'
|
||||
cmd_name += f' --serialization {serialisation}'
|
||||
if llm.__llm_backend__ in {'pt', 'vllm'}:
|
||||
cmd_name += f' --serialization {serialisation}'
|
||||
if adapter_map is not None:
|
||||
cmd_name += ' ' + ' '.join(
|
||||
[
|
||||
@@ -1042,7 +1043,7 @@ def build_command(
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
quantize=quantize,
|
||||
torch_dtype=dtype,
|
||||
dtype=dtype,
|
||||
serialisation=first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
|
||||
@@ -61,8 +61,8 @@ else:
|
||||
llm = openllm.LLM(
|
||||
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora',
|
||||
model, tokenizer = llm.prepare(
|
||||
'lora',
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
|
||||
@@ -135,9 +135,7 @@ def prepare_for_int4_training(
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f'Found {len(modules)} modules to quantize: {modules}')
|
||||
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
|
||||
)
|
||||
model, tokenizer = llm.prepare('lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -65,8 +65,8 @@ else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize='int8')
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
|
||||
model, tokenizer = llm.prepare(
|
||||
'lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
|
||||
)
|
||||
|
||||
# ft on english_quotes
|
||||
|
||||
Reference in New Issue
Block a user