mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-06 13:52:21 -05:00
fix(base-image): update base image to include cuda for now (#720)
* fix(base-image): update base image to include cuda for now Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * fix: build core and client on release images Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: cleanup style changes Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import functools, logging, os, warnings
|
||||
import typing as t
|
||||
import attr, inflection, orjson
|
||||
import bentoml, openllm
|
||||
import functools, logging, os, warnings, typing as t
|
||||
import attr, inflection, orjson, bentoml, openllm
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
@@ -20,9 +18,9 @@ from openllm_core._typing_compat import (
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
ReprMixin,
|
||||
apply,
|
||||
check_bool_env,
|
||||
codegen,
|
||||
first_not_none,
|
||||
flatten_attrs,
|
||||
@@ -142,31 +140,33 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
|
||||
# NOTE: If you are here to see how generate_iterator and generate works, see above.
|
||||
# The below are mainly for internal implementation that you don't have to worry about.
|
||||
# fmt: off
|
||||
|
||||
_model_id:str
|
||||
_revision:t.Optional[str]
|
||||
_quantization_config:t.Optional[t.Union[transformers.BitsAndBytesConfig,transformers.GPTQConfig,transformers.AwqConfig]]
|
||||
_model_id: str
|
||||
_revision: t.Optional[str]
|
||||
_quantization_config: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
]
|
||||
_quantise: t.Optional[LiteralQuantise]
|
||||
_model_decls:TupleAny
|
||||
__model_attrs:DictStrAny
|
||||
__tokenizer_attrs:DictStrAny
|
||||
_tag:bentoml.Tag
|
||||
_adapter_map:t.Optional[AdapterMap]
|
||||
_serialisation:LiteralSerialisation
|
||||
_local:bool
|
||||
_max_model_len:t.Optional[int]
|
||||
_model_decls: TupleAny
|
||||
__model_attrs: DictStrAny
|
||||
__tokenizer_attrs: DictStrAny
|
||||
_tag: bentoml.Tag
|
||||
_adapter_map: t.Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: t.Optional[int]
|
||||
|
||||
__llm_dtype__: t.Union[LiteralDtype,t.Literal['auto', 'half', 'float']]='auto'
|
||||
__llm_torch_dtype__:'torch.dtype'=None
|
||||
__llm_config__:t.Optional[LLMConfig]=None
|
||||
__llm_backend__:LiteralBackend=None
|
||||
__llm_quantization_config__:t.Optional[t.Union[transformers.BitsAndBytesConfig,transformers.GPTQConfig,transformers.AwqConfig]]=None
|
||||
__llm_runner__:t.Optional[Runner[M, T]]=None
|
||||
__llm_model__:t.Optional[M]=None
|
||||
__llm_tokenizer__:t.Optional[T]=None
|
||||
__llm_adapter_map__:t.Optional[ResolvedAdapterMap]=None
|
||||
__llm_trust_remote_code__:bool=False
|
||||
__llm_dtype__: t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']] = 'auto'
|
||||
__llm_torch_dtype__: 'torch.dtype' = None
|
||||
__llm_config__: t.Optional[LLMConfig] = None
|
||||
__llm_backend__: LiteralBackend = None
|
||||
__llm_quantization_config__: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
] = None
|
||||
__llm_runner__: t.Optional[Runner[M, T]] = None
|
||||
__llm_model__: t.Optional[M] = None
|
||||
__llm_tokenizer__: t.Optional[T] = None
|
||||
__llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
|
||||
__llm_trust_remote_code__: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -188,26 +188,34 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
_eager=True,
|
||||
**attrs,
|
||||
):
|
||||
torch_dtype=attrs.pop('torch_dtype',None) # backward compatible
|
||||
if torch_dtype is not None:warnings.warns('The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',DeprecationWarning,stacklevel=3);dtype=torch_dtype
|
||||
torch_dtype = attrs.pop('torch_dtype', None) # backward compatible
|
||||
if torch_dtype is not None:
|
||||
warnings.warns(
|
||||
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
dtype = torch_dtype
|
||||
_local = False
|
||||
if validate_is_path(model_id):model_id,_local=resolve_filepath(model_id),True
|
||||
backend=first_not_none(getenv('backend',default=backend),default=self._cascade_backend())
|
||||
dtype=first_not_none(getenv('dtype',default=dtype,var=['TORCH_DTYPE']),default='auto')
|
||||
quantize=first_not_none(getenv('quantize',default=quantize,var=['QUANITSE']),default=None)
|
||||
attrs.update({'low_cpu_mem_usage':low_cpu_mem_usage})
|
||||
if validate_is_path(model_id):
|
||||
model_id, _local = resolve_filepath(model_id), True
|
||||
backend = first_not_none(getenv('backend', default=backend), default=self._cascade_backend())
|
||||
dtype = first_not_none(getenv('dtype', default=dtype, var=['TORCH_DTYPE']), default='auto')
|
||||
quantize = first_not_none(getenv('quantize', default=quantize, var=['QUANITSE']), default=None)
|
||||
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs,tokenizer_attrs=flatten_attrs(**attrs)
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
if model_tag is None:
|
||||
model_tag,model_version=self._make_tag_components(model_id,model_version,backend=backend)
|
||||
if model_version:model_tag=f'{model_tag}:{model_version}'
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
self.__attrs_init__(
|
||||
model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(model_tag),
|
||||
quantization_config=quantization_config,
|
||||
quantise=getattr(self._Quantise,backend)(self,quantize),
|
||||
quantise=getattr(self._Quantise, backend)(self, quantize),
|
||||
model_decls=args,
|
||||
adapter_map=convert_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
@@ -220,143 +228,248 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code,
|
||||
)
|
||||
|
||||
if _eager:
|
||||
try:
|
||||
model=bentoml.models.get(self.tag)
|
||||
model = bentoml.models.get(self.tag)
|
||||
except bentoml.exceptions.NotFound:
|
||||
model=openllm.serialisation.import_model(self,trust_remote_code=self.trust_remote_code)
|
||||
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
|
||||
# resolve the tag
|
||||
self._tag=model.tag
|
||||
if not _eager and embedded:raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if embedded:logger.warning('Models will be loaded into memory. NOT RECOMMENDED in production and SHOULD ONLY used for development.');self.runner.init_local(quiet=True)
|
||||
self._tag = model.tag
|
||||
if not _eager and embedded:
|
||||
raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
|
||||
if embedded:
|
||||
logger.warning(
|
||||
'NOT RECOMMENDED in production and SHOULD ONLY used for development (Loading into current memory).'
|
||||
)
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
class _Quantise:
|
||||
@staticmethod
|
||||
def pt(llm:LLM,quantise=None):return quantise
|
||||
@staticmethod
|
||||
def vllm(llm:LLM,quantise=None):return quantise
|
||||
@staticmethod
|
||||
def ctranslate(llm:LLM,quantise=None):
|
||||
if quantise in {'int4','awq','gptq','squeezellm'}:raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise=='int8':quantise='int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
def pt(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
@apply(lambda val:tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self,model_id:str,model_version:str|None,backend:str)->tuple[str,str|None]:
|
||||
model_id,*maybe_revision=model_id.rsplit(':')
|
||||
if len(maybe_revision)>0:
|
||||
if model_version is not None:logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",maybe_revision[0],model_version)
|
||||
|
||||
@staticmethod
|
||||
def vllm(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@staticmethod
|
||||
def ctranslate(llm: LLM, quantise=None):
|
||||
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}:
|
||||
raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise == 'int8':
|
||||
quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
return quantise
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None:
|
||||
logger.warning(
|
||||
"revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version
|
||||
)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id):model_id,model_version=resolve_filepath(model_id),first_not_none(model_version,default=generate_hash_from_file(model_id))
|
||||
return f'{backend}-{normalise_model_name(model_id)}',model_version
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = (
|
||||
resolve_filepath(model_id),
|
||||
first_not_none(model_version, default=generate_hash_from_file(model_id)),
|
||||
)
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
|
||||
@functools.cached_property
|
||||
def _has_gpus(self):
|
||||
try:
|
||||
from cuda import cuda
|
||||
err,*_=cuda.cuInit(0)
|
||||
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err,num_gpus=cuda.cuDeviceGetCount()
|
||||
if err!=cuda.CUresult.CUDA_SUCCESS:raise RuntimeError('Failed to get CUDA device count.')
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err, num_gpus = cuda.cuDeviceGetCount()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to get CUDA device count.')
|
||||
return True
|
||||
except (ImportError, RuntimeError):return False
|
||||
except (ImportError, RuntimeError):
|
||||
return False
|
||||
|
||||
@property
|
||||
def _torch_dtype(self):
|
||||
import torch, transformers
|
||||
_map=_torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__,torch.dtype):
|
||||
try:hf_config=transformers.AutoConfig.from_pretrained(self.bentomodel.path,trust_remote_code=self.trust_remote_code)
|
||||
except OpenLLMException:hf_config=transformers.AutoConfig.from_pretrained(self.model_id,trust_remote_code=self.trust_remote_code)
|
||||
config_dtype=getattr(hf_config,'torch_dtype',None)
|
||||
if config_dtype is None:config_dtype=torch.float32
|
||||
if self.__llm_dtype__=='auto':
|
||||
if config_dtype==torch.float32:torch_dtype=torch.float16
|
||||
else:torch_dtype=config_dtype
|
||||
|
||||
_map = _torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
|
||||
try:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(
|
||||
self.bentomodel.path, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
except OpenLLMException:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
if self.__llm_dtype__ == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
if self.__llm_dtype__ not in _map:raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
torch_dtype=_map[self.__llm_dtype__]
|
||||
self.__llm_torch_dtype__=torch_dtype
|
||||
if self.__llm_dtype__ not in _map:
|
||||
raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
|
||||
torch_dtype = _map[self.__llm_dtype__]
|
||||
self.__llm_torch_dtype__ = torch_dtype
|
||||
return self.__llm_torch_dtype__
|
||||
|
||||
@property
|
||||
def _model_attrs(self):return {**self.import_kwargs[0],**self.__model_attrs}
|
||||
def _model_attrs(self):
|
||||
return {**self.import_kwargs[0], **self.__model_attrs}
|
||||
|
||||
@_model_attrs.setter
|
||||
def _model_attrs(self, value):self.__model_attrs = value
|
||||
def _model_attrs(self, value):
|
||||
self.__model_attrs = value
|
||||
|
||||
@property
|
||||
def _tokenizer_attrs(self):return {**self.import_kwargs[1],**self.__tokenizer_attrs}
|
||||
def _cascade_backend(self)->LiteralBackend:
|
||||
def _tokenizer_attrs(self):
|
||||
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
|
||||
|
||||
def _cascade_backend(self) -> LiteralBackend:
|
||||
if self._has_gpus:
|
||||
if is_vllm_available():return 'vllm'
|
||||
elif is_ctranslate_available():return 'ctranslate' # XXX: base OpenLLM image should always include vLLM
|
||||
elif is_ctranslate_available():return 'ctranslate'
|
||||
else:return 'pt'
|
||||
def __setattr__(self,attr,value):
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
if is_vllm_available():
|
||||
return 'vllm'
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate' # XXX: base OpenLLM image should always include vLLM
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate'
|
||||
else:
|
||||
return 'pt'
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:
|
||||
raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
|
||||
super().__setattr__(attr, value)
|
||||
def __del__(self):del self.__llm_model__,self.__llm_tokenizer__,self.__llm_adapter_map__
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def __repr_keys__(self):return {'model_id','revision','backend','type'}
|
||||
def __repr_keys__(self):
|
||||
return {'model_id', 'revision', 'backend', 'type'}
|
||||
|
||||
def __repr_args__(self):
|
||||
yield 'model_id',self._model_id if not self._local else self.tag.name
|
||||
yield 'revision',self._revision if self._revision else self.tag.version
|
||||
yield 'backend',self.__llm_backend__
|
||||
yield 'type',self.llm_type
|
||||
yield 'model_id', self._model_id if not self._local else self.tag.name
|
||||
yield 'revision', self._revision if self._revision else self.tag.version
|
||||
yield 'backend', self.__llm_backend__
|
||||
yield 'type', self.llm_type
|
||||
|
||||
@property
|
||||
def import_kwargs(self):return {'device_map':'auto' if self._has_gpus else None,'torch_dtype':self._torch_dtype},{'padding_side':'left','truncation_side':'left'}
|
||||
def import_kwargs(self):
|
||||
return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {
|
||||
'padding_side': 'left',
|
||||
'truncation_side': 'left',
|
||||
}
|
||||
|
||||
@property
|
||||
def trust_remote_code(self):
|
||||
env=os.getenv('TRUST_REMOTE_CODE')
|
||||
if env is not None:return str(env).upper() in ENV_VARS_TRUE_VALUES
|
||||
env = os.getenv('TRUST_REMOTE_CODE')
|
||||
if env is not None:
|
||||
check_bool_env('TRUST_REMOTE_CODE', env)
|
||||
return self.__llm_trust_remote_code__
|
||||
|
||||
@property
|
||||
def model_id(self):return self._model_id
|
||||
def model_id(self):
|
||||
return self._model_id
|
||||
|
||||
@property
|
||||
def revision(self):return self._revision
|
||||
def revision(self):
|
||||
return self._revision
|
||||
|
||||
@property
|
||||
def tag(self):return self._tag
|
||||
def tag(self):
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def bentomodel(self):return openllm.serialisation.get(self)
|
||||
def bentomodel(self):
|
||||
return openllm.serialisation.get(self)
|
||||
|
||||
@property
|
||||
def quantization_config(self):
|
||||
if self.__llm_quantization_config__ is None:
|
||||
from ._quantisation import infer_quantisation_config
|
||||
if self._quantization_config is not None:self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(self,self._quantise,**self._model_attrs)
|
||||
else:raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
|
||||
if self._quantization_config is not None:
|
||||
self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:
|
||||
self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
|
||||
self, self._quantise, **self._model_attrs
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
return self.__llm_quantization_config__
|
||||
|
||||
@property
|
||||
def has_adapters(self):return self._adapter_map is not None
|
||||
def has_adapters(self):
|
||||
return self._adapter_map is not None
|
||||
|
||||
@property
|
||||
def local(self):return self._local
|
||||
def local(self):
|
||||
return self._local
|
||||
|
||||
@property
|
||||
def quantise(self):return self._quantise
|
||||
def quantise(self):
|
||||
return self._quantise
|
||||
|
||||
@property
|
||||
def llm_type(self):return normalise_model_name(self._model_id)
|
||||
def llm_type(self):
|
||||
return normalise_model_name(self._model_id)
|
||||
|
||||
@property
|
||||
def llm_parameters(self):return (self._model_decls,self._model_attrs),self._tokenizer_attrs
|
||||
def llm_parameters(self):
|
||||
return (self._model_decls, self._model_attrs), self._tokenizer_attrs
|
||||
|
||||
@property
|
||||
def identifying_params(self):return {'configuration':self.config.model_dump_json().decode(),'model_ids':orjson.dumps(self.config['model_ids']).decode(),'model_id':self.model_id}
|
||||
def identifying_params(self):
|
||||
return {
|
||||
'configuration': self.config.model_dump_json().decode(),
|
||||
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
|
||||
'model_id': self.model_id,
|
||||
}
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
if self.__llm_tokenizer__ is None:
|
||||
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
return self.__llm_tokenizer__
|
||||
|
||||
@property
|
||||
def runner(self):
|
||||
from ._runners import runner
|
||||
if self.__llm_runner__ is None:self.__llm_runner__=runner(self)
|
||||
|
||||
if self.__llm_runner__ is None:
|
||||
self.__llm_runner__ = runner(self)
|
||||
return self.__llm_runner__
|
||||
def prepare(self,adapter_type='lora',use_gradient_checking=True,**attrs):
|
||||
if self.__llm_backend__!='pt':raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
|
||||
def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
|
||||
if self.__llm_backend__ != 'pt':
|
||||
raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
|
||||
from peft.mapping import get_peft_model
|
||||
from peft.utils.other import prepare_model_for_kbit_training
|
||||
model=get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model,use_gradient_checkpointing=use_gradient_checking),
|
||||
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking),
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type,self.config.make_fine_tune_config(adapter_type))
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train()
|
||||
.with_config(**attrs)
|
||||
.build(),
|
||||
)
|
||||
if DEBUG:model.print_trainable_parameters()
|
||||
return model,self.tokenizer
|
||||
def prepare_for_training(self,*args,**attrs):logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Please use `prepare` instead.');return self.prepare(*args,**attrs)
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters()
|
||||
return model, self.tokenizer
|
||||
|
||||
def prepare_for_training(self, *args, **attrs):
|
||||
logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.')
|
||||
return self.prepare(*args, **attrs)
|
||||
|
||||
@property
|
||||
def adapter_map(self):
|
||||
@@ -431,33 +544,49 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
return self.__llm_config__
|
||||
|
||||
|
||||
# fmt: off
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _torch_dtype_mapping()->dict[str,torch.dtype]:
|
||||
import torch; return {
|
||||
def _torch_dtype_mapping() -> dict[str, torch.dtype]:
|
||||
import torch
|
||||
|
||||
return {
|
||||
'half': torch.float16,
|
||||
'float': torch.float32,
|
||||
'float16': torch.float16,
|
||||
'float': torch.float32,
|
||||
'float32': torch.float32,
|
||||
'bfloat16': torch.bfloat16,
|
||||
}
|
||||
def normalise_model_name(name:str)->str:return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/','--'))
|
||||
def convert_peft_config_type(adapter_map:dict[str, str])->AdapterMap:
|
||||
if not is_peft_available():raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
return (
|
||||
os.path.basename(resolve_filepath(name))
|
||||
if validate_is_path(name)
|
||||
else inflection.dasherize(name.replace('/', '--'))
|
||||
)
|
||||
|
||||
|
||||
def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
if not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
resolved:AdapterMap={}
|
||||
resolved: AdapterMap = {}
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if name is None:raise ValueError('Adapter name must be specified.')
|
||||
if name is None:
|
||||
raise ValueError('Adapter name must be specified.')
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file=os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
try:
|
||||
config_file=hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
except Exception as err:
|
||||
raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err
|
||||
with open(config_file, 'r') as file:resolved_config=orjson.loads(file.read())
|
||||
_peft_type=resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved:resolved[_peft_type]=()
|
||||
resolved[_peft_type]+=(_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
with open(config_file, 'r') as file:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
_peft_type = resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
Reference in New Issue
Block a user