feat(engine): CTranslate2 (#698)

* chore: update instruction for dependencies

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat(experimental): CTranslate2

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-19 10:25:08 -05:00
committed by GitHub
parent 539f250c0f
commit 816c1ee80e
31 changed files with 945 additions and 350 deletions

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import functools
import importlib.util
import logging
import os
import types
@@ -43,6 +42,7 @@ from openllm_core.utils import (
get_disable_warnings,
get_quiet_mode,
getenv,
is_ctranslate_available,
is_peft_available,
is_vllm_available,
resolve_filepath,
@@ -165,32 +165,28 @@ class LLM(t.Generic[M, T], ReprMixin):
low_cpu_mem_usage=True,
**attrs,
):
# backward compatible
torch_dtype = attrs.pop('torch_dtype', None)
if torch_dtype is not None:
logger.warning(
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.'
)
dtype = torch_dtype
# fmt: off
torch_dtype = attrs.pop('torch_dtype',None) # backward compatible
if torch_dtype is not None:logger.warning('The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.');dtype=torch_dtype
_local = False
if validate_is_path(model_id):
model_id, _local = resolve_filepath(model_id), True
backend = first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt')
dtype = first_not_none(getenv('dtype', default=dtype, var=['TORCH_DTYPE']), default='auto')
quantize = first_not_none(getenv('quantize', default=quantize, var=['QUANITSE']), default=None)
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
if validate_is_path(model_id):model_id,_local=resolve_filepath(model_id),True
backend=first_not_none(getenv('backend',default=backend),default=self._cascade_backend())
dtype=first_not_none(getenv('dtype',default=dtype,var=['TORCH_DTYPE']),default='auto')
quantize=first_not_none(getenv('quantize',default=quantize,var=['QUANITSE']),default=None)
attrs.update({'low_cpu_mem_usage':low_cpu_mem_usage})
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
if model_tag is None:
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
if model_version:
model_tag = f'{model_tag}:{model_version}'
model_tag,model_version=self._make_tag_components(model_id,model_version,backend=backend)
if model_version:model_tag=f'{model_tag}:{model_version}'
# fmt: on
self.__attrs_init__(
model_id=model_id,
revision=model_version,
tag=bentoml.Tag.from_taglike(model_tag),
quantization_config=quantization_config,
quantise=quantize,
quantise=self._resolve_quantise(quantize, backend),
model_decls=args,
adapter_map=_resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
serialisation=serialisation,
@@ -217,63 +213,66 @@ class LLM(t.Generic[M, T], ReprMixin):
)
self.runner.init_local(quiet=True)
# fmt: off
def _resolve_quantise(self, quantise, backend):
if backend in ('pt', 'vllm'):return quantise
if backend=='ctranslate':return self._resolve_ctranslate_quantise(quantise)
raise NotImplementedError(f"Quantisation is not supported for backend '{self.__llm_backend__}'")
def _resolve_ctranslate_quantise(self,quantise):
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}:raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
if quantise == 'int8':quantise='int8_float16' if self._has_gpus else 'int8_float32'
return quantise
@apply(lambda val:tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self,model_id:str,model_version:str|None,backend:str)->tuple[str,str|None]:
model_id,*maybe_revision=model_id.rsplit(':')
if len(maybe_revision)>0:
if model_version is not None:logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",maybe_revision[0],model_version)
model_version = maybe_revision[0]
if validate_is_path(model_id):model_id,model_version=resolve_filepath(model_id),first_not_none(model_version,default=generate_hash_from_file(model_id))
return f'{backend}-{normalise_model_name(model_id)}',model_version
@functools.cached_property
def _has_gpus(self):
try:
from cuda import cuda
err,*_=cuda.cuInit(0)
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to initialise CUDA runtime binding.')
err,num_gpus=cuda.cuDeviceGetCount()
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to get CUDA device count.')
return True
except (ImportError, RuntimeError):return False
@property
def _torch_dtype(self):
import torch
import transformers
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
try:
hf_config = transformers.AutoConfig.from_pretrained(
self.bentomodel.path, trust_remote_code=self.trust_remote_code
)
except OpenLLMException:
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
config_dtype = getattr(hf_config, 'torch_dtype', None)
if config_dtype is None:
config_dtype = torch.float32
if self.__llm_dtype__ == 'auto':
if config_dtype == torch.float32:
torch_dtype = torch.float16 # following common practice
else:
torch_dtype = config_dtype
import torch, transformers # noqa: I001
_map=_torch_dtype_mapping()
if not isinstance(self.__llm_torch_dtype__,torch.dtype):
try:hf_config=transformers.AutoConfig.from_pretrained(self.bentomodel.path,trust_remote_code=self.trust_remote_code)
except OpenLLMException:hf_config=transformers.AutoConfig.from_pretrained(self.model_id,trust_remote_code=self.trust_remote_code)
config_dtype=getattr(hf_config,'torch_dtype',None)
if config_dtype is None:config_dtype=torch.float32
if self.__llm_dtype__=='auto':
if config_dtype==torch.float32:torch_dtype=torch.float16
else:torch_dtype=config_dtype
else:
if self.__llm_dtype__ not in _torch_dtype_mapping():
raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
torch_dtype = _torch_dtype_mapping()[self.__llm_dtype__]
self.__llm_torch_dtype__ = torch_dtype
if self.__llm_dtype__ not in _map:raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
torch_dtype=_map[self.__llm_dtype__]
self.__llm_torch_dtype__=torch_dtype
return self.__llm_torch_dtype__
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
model_id, *maybe_revision = model_id.rsplit(':')
if len(maybe_revision) > 0:
if model_version is not None:
logger.warning(
"revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",
maybe_revision[0],
model_version,
)
model_version = maybe_revision[0]
if validate_is_path(model_id):
model_id, model_version = (
resolve_filepath(model_id),
first_not_none(model_version, default=generate_hash_from_file(model_id)),
)
return f'{backend}-{normalise_model_name(model_id)}', model_version
def __setattr__(self, attr, value):
if attr in _reserved_namespace:
raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr, value)
# fmt: off
@property
def _model_attrs(self):return {**self.import_kwargs[0],**self.__model_attrs}
@_model_attrs.setter
def _model_attrs(self, value):self.__model_attrs = value
@property
def _tokenizer_attrs(self):return {**self.import_kwargs[1],**self.__tokenizer_attrs}
def _cascade_backend(self)->LiteralBackend:
if self._has_gpus:
if is_vllm_available():return 'vllm'
elif is_ctranslate_available():return 'ctranslate' # XXX: base OpenLLM image should always include vLLM
elif is_ctranslate_available():return 'ctranslate'
else:return 'pt'
def __setattr__(self,attr,value):
if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr, value)
def __del__(self):del self.__llm_model__,self.__llm_tokenizer__,self.__llm_adapter_map__
@property
def __repr_keys__(self):return {'model_id','revision','backend','type'}
def __repr_args__(self):
@@ -282,10 +281,10 @@ class LLM(t.Generic[M, T], ReprMixin):
yield 'backend',self.__llm_backend__
yield 'type',self.llm_type
@property
def import_kwargs(self):import torch;return {'device_map':'auto' if torch.cuda.is_available() else None, 'torch_dtype':self._torch_dtype},{'padding_side':'left','truncation_side':'left'} # noqa: I001
def import_kwargs(self):return {'device_map':'auto' if self._has_gpus else None,'torch_dtype':self._torch_dtype},{'padding_side':'left','truncation_side':'left'}
@property
def trust_remote_code(self):
env = os.getenv('TRUST_REMOTE_CODE')
env=os.getenv('TRUST_REMOTE_CODE')
if env is not None:return str(env).upper() in ENV_VARS_TRUE_VALUES
return self.__llm_trust_remote_code__
@property
@@ -319,10 +318,6 @@ class LLM(t.Generic[M, T], ReprMixin):
@property
def identifying_params(self):return {'configuration':self.config.model_dump_json().decode(),'model_ids':orjson.dumps(self.config['model_ids']).decode(),'model_id':self.model_id}
@property
def config(self):
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
return self.__llm_config__
@property
def tokenizer(self):
if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
return self.__llm_tokenizer__
@@ -330,8 +325,42 @@ class LLM(t.Generic[M, T], ReprMixin):
def runner(self):
if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self)
return self.__llm_runner__
def prepare(self,adapter_type='lora',use_gradient_checking=True,**attrs):
if self.__llm_backend__!='pt':raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
from peft.mapping import get_peft_model
from peft.utils.other import prepare_model_for_kbit_training
model=get_peft_model(
prepare_model_for_kbit_training(self.model,use_gradient_checkpointing=use_gradient_checking),
self.config['fine_tune_strategies']
.get(adapter_type,self.config.make_fine_tune_config(adapter_type))
.train()
.with_config(**attrs)
.build(),
)
if DEBUG:model.print_trainable_parameters()
return model,self.tokenizer
def prepare_for_training(self,*args,**attrs):logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Please use `prepare` instead.');return self.prepare(*args,**attrs)
# fmt: on
@property
def adapter_map(self):
if not is_peft_available():
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
if not self.has_adapters:
raise AttributeError('Adapter map is not available.')
assert self._adapter_map is not None
if self.__llm_adapter_map__ is None:
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
for adapter_type, adapter_tuple in self._adapter_map.items():
base = first_not_none(
self.config['fine_tune_strategies'].get(adapter_type),
default=self.config.make_fine_tune_config(adapter_type),
)
for adapter in adapter_tuple:
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
self.__llm_adapter_map__ = _map
return self.__llm_adapter_map__
@property
def model(self):
if self.__llm_model__ is None:
@@ -359,41 +388,31 @@ class LLM(t.Generic[M, T], ReprMixin):
return self.__llm_model__
@property
def adapter_map(self):
if importlib.util.find_spec('peft') is None:
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
if not self.has_adapters:
raise AttributeError('Adapter map is not available.')
assert self._adapter_map is not None
if self.__llm_adapter_map__ is None:
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
for adapter_type, adapter_tuple in self._adapter_map.items():
base = first_not_none(
self.config['fine_tune_strategies'].get(adapter_type),
default=self.config.make_fine_tune_config(adapter_type),
)
for adapter in adapter_tuple:
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
self.__llm_adapter_map__ = _map
return self.__llm_adapter_map__
def config(self):
import transformers
def prepare_for_training(self, adapter_type='lora', use_gradient_checking=True, **attrs):
from peft.mapping import get_peft_model
from peft.utils.other import prepare_model_for_kbit_training
peft_config = (
self.config['fine_tune_strategies']
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
.train()
.with_config(**attrs)
.build()
)
model = get_peft_model(
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config
)
if DEBUG:
model.print_trainable_parameters()
return model, self.tokenizer
if self.__llm_config__ is None:
if self.__llm_backend__ == 'ctranslate':
try:
config = transformers.AutoConfig.from_pretrained(
self.bentomodel.path_of('/hf'), trust_remote_code=self.trust_remote_code
)
except OpenLLMException:
config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
config = openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
).model_construct_env(**self._model_attrs)
break
else:
raise OpenLLMException(
f"Failed to infer the configuration class from the given model. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}"
)
else:
config = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
self.__llm_config__ = config
return self.__llm_config__
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
@@ -476,7 +495,7 @@ def _RunnerFactory(
scheduling_strategy = CascadingResourceStrategy
backend = first_not_none(backend, os.getenv('OPENLLM_BACKEND', default=llm.__llm_backend__))
backend = first_not_none(getenv('backend', default=backend), default=llm.__llm_backend__)
models = models if models is not None else []
try:
@@ -533,7 +552,7 @@ def _RunnerFactory(
}
),
)(
runnable(backend),
runnable(llm, backend),
name=llm.runner_name,
embedded=False,
models=models,

View File

@@ -135,7 +135,7 @@ class LLM(Generic[M, T]):
def runner(self) -> Runner[M, T]: ...
@property
def adapter_map(self) -> ResolvedAdapterMap: ...
def prepare_for_training(
def prepare(
self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any
) -> Tuple[InjectedModel, T]: ...
async def generate(

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import gc
import os
import traceback
import typing as t
@@ -10,14 +9,87 @@ import bentoml
import openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_vllm_available
from openllm_core.utils import first_not_none, getenv, is_ctranslate_available
__all__ = ['runnable']
def runnable(backend=None):
backend = first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt')
return vLLMRunnable if backend == 'vllm' else PyTorchRunnable
def runnable(llm, backend=None):
backend = first_not_none(getenv('backend', default=backend), default=llm._cascade_backend())
if backend == 'vllm':
return vLLMRunnable
elif backend == 'pt':
return PyTorchRunnable
elif backend == 'ctranslate':
return CTranslateRunnable
else:
raise OpenLLMException(f'Unsupported backend: {backend}')
class CTranslateRunnable(bentoml.Runnable):
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
SUPPORTS_CPU_MULTI_THREADING = True
def __init__(self, llm):
if not is_ctranslate_available():
raise OpenLLMException('ctranslate is not installed. Please install it with `pip install "openllm[ctranslate]"`')
self.config = llm.config
self.model = llm.model
self.tokenizer = llm.tokenizer
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
raise NotImplementedError('Adapter is not supported with CTranslate.')
stop_ = set()
if isinstance(stop, str) and stop != '':
stop_.add(stop)
elif isinstance(stop, t.Iterable):
stop_.update(stop)
config = self.config.model_construct_env(stop=list(stop_), **attrs)
sampling_params = dict(
max_length=config['max_new_tokens'],
min_length=config['min_length'],
sampling_topk=config['top_k'],
sampling_topp=config['top_p'],
sampling_temperature=config['temperature'],
return_log_prob=config['logprobs'] > 0,
repetition_penalty=config['repetition_penalty'],
no_repeat_ngram_size=config['no_repeat_ngram_size'],
end_token=config['stop'],
)
cumulative_logprob = 0.0
output_token_ids = list(prompt_token_ids)
input_len = len(prompt_token_ids)
async for request_output in self.model.async_generate_tokens(
self.tokenizer.convert_ids_to_tokens(prompt_token_ids), **sampling_params
):
cumulative_logprob += request_output.log_prob if config['logprobs'] else 0.0
output_token_ids.append(request_output.token_id)
text = self.tokenizer.decode(
output_token_ids[input_len:],
skip_special_tokens=True,
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
yield GenerationOutput(
prompt='',
finished=request_output.is_last,
outputs=[
CompletionChunk(
index=0,
text=text,
token_ids=output_token_ids[input_len:],
cumulative_logprob=cumulative_logprob,
finish_reason=None,
# TODO: logprobs, but seems like we don't have access to the raw logits
)
],
prompt_token_ids=prompt_token_ids,
request_id=request_id,
).model_dump_json()
class vLLMRunnable(bentoml.Runnable):
@@ -44,7 +116,7 @@ class vLLMRunnable(bentoml.Runnable):
trust_remote_code=llm.trust_remote_code,
tokenizer_mode='auto',
tensor_parallel_size=num_gpus,
dtype=str(llm._torch_dtype).split('.')[-1],
dtype=llm._torch_dtype,
quantization=quantization,
worker_use_ray=False,
engine_use_ray=False,
@@ -242,7 +314,7 @@ class PyTorchRunnable(bentoml.Runnable):
CompletionChunk(
index=0,
text=text,
token_ids=output_token_ids[input_len:],
token_ids=tmp_output_ids,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,

View File

@@ -18,7 +18,7 @@ from typing import (
from bentoml import Model, Strategy, Tag
from bentoml._internal.runner.runner_handle import RunnerHandle
from openllm_core import LLMConfig
from openllm_core._typing_compat import LiteralBackend, T, overload
from openllm_core._typing_compat import LiteralBackend, M, T, overload
from ._llm import LLM
@@ -32,10 +32,16 @@ try:
except ImportError:
PreTrainedModel = Any
try:
from ctranslate2 import Generator, Translator
except ImportError:
Translator = Any
Generator = Any
Mo = TypeVar('Mo')
class _Runnable(Protocol[Mo]):
SUPPORTED_RESOURCES: Tuple[Literal['nvidia.com/gpu'], Literal['amd.com/gpu'], Literal['cpu']] = ...
SUPPORTED_RESOURCES: Tuple[Literal['nvidia.com/gpu', 'amd.com/gpu', 'cpu'], ...] = ...
SUPPORTS_CPU_MULTI_THREADING: bool = ...
config: LLMConfig = ...
model: Mo = ...
@@ -57,6 +63,10 @@ class RunnerMethod(Generic[In, Ret]): ...
@final
class vLLMRunnable(_Runnable[AsyncLLMEngine]): ...
@final
class CTranslateRunnable(_Runnable[Union[Translator, Generator]]):
tokenizer: Any
@final
class PyTorchRunnable(_Runnable[PreTrainedModel]):
tokenizer: Any
@@ -70,11 +80,15 @@ class PyTorchRunnable(_Runnable[PreTrainedModel]):
) -> AsyncGenerator[str, None]: ...
@overload
def runnable(backend: Literal['vllm']) -> Type[vLLMRunnable]: ...
def runnable(llm: LLM[M, T], backend: Literal['vllm']) -> Type[vLLMRunnable]: ...
@overload
def runnable(backend: Literal['pt']) -> Type[PyTorchRunnable]: ...
def runnable(llm: LLM[M, T], backend: Literal['pt']) -> Type[PyTorchRunnable]: ...
@overload
def runnable(backend: Optional[str] = ...) -> Type[Union[vLLMRunnable, PyTorchRunnable]]: ...
def runnable(llm: LLM[M, T], backend: Literal['ctranslate']) -> Type[CTranslateRunnable]: ...
@overload
def runnable(
llm: LLM[M, T], backend: Optional[str] = ...
) -> Type[Union[vLLMRunnable, PyTorchRunnable, CTranslateRunnable]]: ...
class Runner(Protocol[Mo, T]):
__doc__: str = ...

View File

@@ -0,0 +1,159 @@
import contextlib
import attr
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions, ModelSignature
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import is_autogptq_available
_object_setattr = object.__setattr__
def get_hash(config) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def patch_correct_tag(llm, config, _revision=None):
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None:
return
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm.tag.version is None:
# HACK: This copies the correct revision into llm.tag
_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision))
if llm._revision is None:
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None):
if metadata is None:
metadata = {}
metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__})
if llm.quantise:
metadata['_quantize'] = llm.quantise
architectures = getattr(config, 'architectures', [])
if not architectures:
if trust_remote_code:
auto_map = getattr(config, 'auto_map', {})
if not auto_map:
raise RuntimeError(
f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}'
)
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if autoclass not in auto_map:
raise RuntimeError(
f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models."
)
architectures = [auto_map[autoclass]]
else:
raise RuntimeError(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata.update(
{'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision}
)
return metadata
def _create_signatures(llm, signatures=None):
if signatures is None:
signatures = {}
if llm.__llm_backend__ == 'pt':
if llm.quantise == 'gptq':
if not is_autogptq_available():
raise OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
signatures['generate'] = {'batchable': False}
else:
signatures.update(
{
k: ModelSignature(batchable=False)
for k in (
'__call__',
'forward',
'generate',
'contrastive_search',
'greedy_search',
'sample',
'beam_search',
'beam_sample',
'group_beam_search',
'constrained_beam_search',
)
}
)
elif llm.__llm_backend__ == 'ctranslate':
if llm.config['model_type'] == 'seq2seq_lm':
non_batch_keys = {'score_file', 'translate_file'}
batch_keys = {'generate_tokens', 'score_batch', 'translate_batch', 'translate_iterable', 'score_iterable'}
else:
non_batch_keys = set()
batch_keys = {
'async_generate_tokens',
'forward_batch',
'generate_batch',
'generate_iterable',
'generate_tokens',
'score_batch',
'score_iterable',
}
signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys})
signatures.update({k: ModelSignature(batchable=True) for k in batch_keys})
return signatures
@inject
@contextlib.contextmanager
def save_model(
llm,
config,
safe_serialisation,
trust_remote_code,
module,
external_modules,
_model_store=Provide[BentoMLContainer.model_store],
_api_version='v2.1.0',
):
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module=f'openllm.serialisation.{module}',
api_version=_api_version,
options=ModelOptions(),
context=openllm.utils.generate_context('openllm'),
labels=openllm.utils.generate_labels(llm),
metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code),
signatures=_create_signatures(llm),
)
with openllm.utils.analytics.set_bentoml_tracking():
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
yield bentomodel, imported_modules
except Exception:
raise
else:
bentomodel.flush()
bentomodel.save(_model_store)
openllm.utils.analytics.track(
openllm.utils.analytics.ModelSaveEvent(
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
)
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
return bentomodel

View File

@@ -0,0 +1,24 @@
import types
from contextlib import contextmanager
from typing import Optional, Sequence
import transformers
from bentoml import Model
from openllm_core._typing_compat import M, T
from .._llm import LLM
def get_hash(config: transformers.PretrainedConfig) -> str: ...
def patch_correct_tag(
llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ...
) -> None: ...
@contextmanager
def save_model(
llm: LLM[M, T],
config: transformers.PretrainedConfig,
safe_serialisation: bool,
trust_remote_code: bool,
module: str,
external_modules: Sequence[types.ModuleType],
) -> Model: ...

View File

@@ -0,0 +1,100 @@
import importlib
import logging
import shutil
import transformers
import bentoml
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import is_ctranslate_available
from .._helpers import patch_correct_tag, save_model
from ..transformers._helpers import get_tokenizer, process_config
if not is_ctranslate_available():
raise RuntimeError(
"'ctranslate2' is required to use with backend 'ctranslate'. Install it with 'pip install \"openllm[ctranslate]\"'"
)
import ctranslate2
from ctranslate2.converters.transformers import TransformersConverter
logger = logging.getLogger(__name__)
def _get_class(llm):
return ctranslate2.Translator if llm.config['model_type'] == 'seq2seq_lm' else ctranslate2.Generator
def import_model(llm, *decls, trust_remote_code, **attrs):
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
for it in {'device_map', 'torch_dtype'}:
_base_attrs.pop(it, None) # pop out hf-specific attributes
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
logger.debug(
'Note that CTranslate2 will load into memory for conversion. Refer to https://opennmt.net/CTranslate2/guides/transformers.html for more information.'
)
if not llm._local:
logger.warning(
"It is RECOMMENDED to convert '%s' to CTranslate2 format yourself to utilise CTranslate2's features, then start with `openllm start /path/to/ct2-dir`. OpenLLM will conservely apply quantization for conversion if specified.",
llm.model_id,
)
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
patch_correct_tag(llm, config)
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
with save_model(
llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)]
) as save_metadata:
bentomodel, _ = save_metadata
if llm._local:
shutil.copytree(
llm.model_id,
bentomodel.path,
symlinks=False,
ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'),
dirs_exist_ok=True,
)
else:
TransformersConverter(
llm.model_id,
load_as_float16=llm.quantise in ('float16', 'int8_float16'),
low_cpu_mem_usage=low_cpu_mem_usage,
trust_remote_code=trust_remote_code,
).convert(bentomodel.path, quantization=llm.quantise, force=True)
# Save the original HF configuration to hf
config.save_pretrained(bentomodel.path_of('/hf/'))
tokenizer.save_pretrained(bentomodel.path)
return bentomodel
def get(llm):
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
if backend != llm.__llm_backend__:
raise OpenLLMException(
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
)
patch_correct_tag(
llm,
transformers.AutoConfig.from_pretrained(model.path_of('/hf/'), trust_remote_code=llm.trust_remote_code),
_revision=model.info.metadata.get('_revision'),
)
return model
except Exception as err:
raise OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
def load_model(llm, *decls, **attrs):
device = 'cuda' if llm._has_gpus else 'cpu'
if llm.quantise:
compute_type = llm.quantise
elif llm.__llm_dtype__ == 'half':
compute_type = 'float16'
elif llm.__llm_dtype__ == 'float':
compute_type = 'float32'
else:
compute_type = llm.__llm_dtype__
return _get_class(llm)(llm.bentomodel.path, device=device, compute_type=compute_type)

View File

@@ -1,22 +1,18 @@
from __future__ import annotations
import importlib
import logging
import attr
import orjson
import torch
import transformers
from huggingface_hub import snapshot_download
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions, ModelSignature
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_autogptq_available
from ._helpers import get_hash, infer_autoclass_from_llm, process_config
from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
from .weights import HfIgnore
from .._helpers import patch_correct_tag, save_model
logger = logging.getLogger(__name__)
@@ -24,162 +20,56 @@ __all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__
def _patch_correct_tag(llm, config, _revision=None):
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None:
return
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm.tag.version is None:
_object_setattr(
llm, '_tag', attr.evolve(llm.tag, version=_revision)
) # HACK: This copies the correct revision into llm.tag
if llm._revision is None:
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
@inject
def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
_base_decls, _base_attrs = llm.llm_parameters[0]
def import_model(llm, *decls, trust_remote_code, **attrs):
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
if llm._local:
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm.quantise
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
metadata = {'safe_serialisation': safe_serialisation}
if quantize:
metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
if not architectures:
if trust_remote_code:
auto_map = getattr(config, 'auto_map', {})
if not auto_map:
raise RuntimeError(
f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}'
)
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if autoclass not in auto_map:
raise RuntimeError(
f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models."
)
architectures = [auto_map[autoclass]]
else:
raise RuntimeError(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata['_pretrained_class'] = architectures[0]
if not llm._local:
metadata['_revision'] = get_hash(config)
else:
metadata['_revision'] = llm.revision
signatures = {}
if quantize == 'gptq':
if not openllm.utils.is_autogptq_available():
raise OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
signatures['generate'] = {'batchable': False}
else:
patch_correct_tag(llm, config)
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
if llm.quantise != 'gptq':
attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = llm.__llm_backend__
signatures.update(
{
k: ModelSignature(batchable=False)
for k in (
'__call__',
'forward',
'generate',
'contrastive_search',
'greedy_search',
'sample',
'beam_search',
'beam_sample',
'group_beam_search',
'constrained_beam_search',
)
}
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = None
external_modules = [importlib.import_module(tokenizer.__module__)]
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module='openllm.serialisation.transformers',
api_version='v2.1.0',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
metadata=metadata,
signatures=signatures,
)
with openllm.utils.analytics.set_bentoml_tracking():
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
with save_model(
llm, config, safe_serialisation, trust_remote_code, 'transformers', [importlib.import_module(tokenizer.__module__)]
) as save_metadata:
bentomodel, imported_modules = save_metadata
tokenizer.save_pretrained(bentomodel.path)
if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
attrs['quantization_config'] = llm.quantization_config
if llm.quantise == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id,
*decls,
local_files_only=True,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(
llm.model_id,
local_dir=bentomodel.path,
local_dir_use_symlinks=False,
ignore_patterns=HfIgnore.ignore_patterns(llm),
)
except Exception:
raise
else:
bentomodel.flush() # type: ignore[no-untyped-call]
bentomodel.save(_model_store)
openllm.utils.analytics.track(
openllm.utils.analytics.ModelSaveEvent(
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
)
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id,
*decls,
local_files_only=True,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(
llm.model_id,
local_dir=bentomodel.path,
local_dir_use_symlinks=False,
ignore_patterns=HfIgnore.ignore_patterns(llm),
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
return bentomodel
@@ -191,7 +81,7 @@ def get(llm):
raise OpenLLMException(
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
)
_patch_correct_tag(
patch_correct_tag(
llm,
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
_revision=model.info.metadata.get('_revision'),
@@ -226,7 +116,7 @@ def load_model(llm, *decls, **attrs):
if '_quantize' in llm.bentomodel.info.metadata:
_quantise = llm.bentomodel.info.metadata['_quantize']
if _quantise == 'gptq':
if not openllm.utils.is_autogptq_available():
if not is_autogptq_available():
raise OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)

View File

@@ -11,11 +11,13 @@ from openllm_core.utils import get_disable_warnings, get_quiet_mode
logger = logging.getLogger(__name__)
def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def get_tokenizer(model_id_or_path, trust_remote_code, **attrs):
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_id_or_path, trust_remote_code=trust_remote_code, **attrs
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any):

View File

@@ -34,6 +34,7 @@ from openllm_core.utils import (
is_autogptq_available as is_autogptq_available,
is_bentoml_available as is_bentoml_available,
is_bitsandbytes_available as is_bitsandbytes_available,
is_ctranslate_available as is_ctranslate_available,
is_grpc_available as is_grpc_available,
is_jupyter_available as is_jupyter_available,
is_jupytext_available as is_jupytext_available,

View File

@@ -17,7 +17,6 @@ from openllm_core._typing_compat import (
Concatenate,
DictStrAny,
LiteralBackend,
LiteralQuantise,
LiteralSerialisation,
ParamSpec,
get_literal_args,
@@ -289,10 +288,10 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--dtype',
type=click.Choice(['float16', 'float32', 'bfloat16', 'auto']),
type=str,
envvar='TORCH_DTYPE',
default='auto',
help='Optional dtype for casting tensors for running inference.',
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']. For CTranslate2, it also accepts the following ['int8_float32', 'int8_float16', 'int8_bfloat16']",
**attrs,
)(f)
@@ -341,15 +340,13 @@ def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
# XXX: remove the check for __args__ once we have ggml and mlc supports
return cli_option(
'--backend',
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
type=click.Choice(get_literal_args(LiteralBackend)),
default=None,
envvar='OPENLLM_BACKEND',
show_envvar=True,
help='The implementation for saving this LLM.',
help='Runtime to use for both serialisation/inference engine.',
**attrs,
)(f)
@@ -368,7 +365,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
'--quantise',
'--quantize',
'quantize',
type=click.Choice(get_literal_args(LiteralQuantise)),
type=str,
default=None,
envvar='OPENLLM_QUANTIZE',
show_envvar=True,
@@ -382,6 +379,10 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
> [!NOTE] that the model can also be served with quantized weights.
"""
+ (

View File

@@ -85,9 +85,7 @@ def _start(
"""
from .entrypoint import start_command, start_grpc_command
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(
backend, default='vllm' if is_vllm_available() else 'pt'
)
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = [model_id]
if system_message:

View File

@@ -450,7 +450,7 @@ def start_command(
import torch
if not torch.cuda.is_available():
if backend == 'pt' and not torch.cuda.is_available():
if dtype == 'auto':
dtype = 'float'
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
@@ -465,7 +465,7 @@ def start_command(
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
torch_dtype=dtype,
dtype=dtype,
)
backend_warning(llm.__llm_backend__)
@@ -580,7 +580,7 @@ def start_grpc_command(
import torch
if not torch.cuda.is_available():
if backend == 'pt' and not torch.cuda.is_available():
if dtype == 'auto':
dtype = 'float'
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
@@ -595,7 +595,7 @@ def start_grpc_command(
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
torch_dtype=dtype,
dtype=dtype,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
)
backend_warning(llm.__llm_backend__)
@@ -661,14 +661,14 @@ def process_environ(
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
'BACKEND': llm.__llm_backend__,
'DTYPE': str(llm._torch_dtype).split('.')[-1],
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
}
)
if llm.quantise:
environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
environ['QUANTIZE'] = str(llm.quantise)
if system_message:
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
@@ -695,10 +695,11 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
cmd_name = f'openllm build {model_id}'
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
if llm.quantise:
cmd_name += f' --quantize {llm.quantise}'
cmd_name += f' --serialization {serialisation}'
if llm.__llm_backend__ in {'pt', 'vllm'}:
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
[
@@ -1042,7 +1043,7 @@ def build_command(
system_message=system_message,
backend=backend,
quantize=quantize,
torch_dtype=dtype,
dtype=dtype,
serialisation=first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),

View File

@@ -61,8 +61,8 @@ else:
llm = openllm.LLM(
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
)
model, tokenizer = llm.prepare_for_training(
adapter_type='lora',
model, tokenizer = llm.prepare(
'lora',
lora_alpha=16,
lora_dropout=0.1,
r=16,

View File

@@ -135,9 +135,7 @@ def prepare_for_int4_training(
modules = find_all_linear_names(llm.model)
print(f'Found {len(modules)} modules to quantize: {modules}')
model, tokenizer = llm.prepare_for_training(
adapter_type='lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
)
model, tokenizer = llm.prepare('lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
# pre-process the model by upcasting the layer norms in float 32 for
for name, module in model.named_modules():

View File

@@ -65,8 +65,8 @@ else:
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
llm = openllm.LLM(model_args.model_id, quantize='int8')
model, tokenizer = llm.prepare_for_training(
adapter_type='lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
model, tokenizer = llm.prepare(
'lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
)
# ft on english_quotes