Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-18 22:55:08 -05:00)
feat(openai): chat templates and complete control of prompt generation (#725)
* feat(openai): chat templates and complete control of prompt generation
* fix: correctly use base chat templates
* fix: remove symlink

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
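The commit exposes two new request fields on the OpenAI-compatible `/v1/chat/completions` route, `chat_template` and `add_generation_prompt` (see the `ChatCompletionRequest` hunk at the end of the diff). A rough client-side sketch of how they might be used follows; the server address, model id, and the Jinja template string are illustrative assumptions, not part of this commit:

```python
# Hedged sketch: calling the OpenAI-compatible endpoint with the new fields.
# Host/port, model id and the template string below are assumptions for illustration.
import requests

payload = {
  'model': 'mistral-7b-instruct-v0.1',  # query GET /v1/models for the id your server expects
  'messages': [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Summarise what chat templates are in one sentence.'},
  ],
  # Optional: override the tokenizer's bundled Jinja chat template for this request.
  'chat_template': "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}",
  # Ask the server to append the assistant generation prompt after the rendered messages.
  'add_generation_prompt': True,
  'max_tokens': 256,
  'stream': False,
}
resp = requests.post('http://localhost:3000/v1/chat/completions', json=payload, timeout=60)
print(resp.json())
```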
@@ -140,9 +140,7 @@ class LLM(t.Generic[M, T], ReprMixin):
   # The below are mainly for internal implementation that you don't have to worry about.
   _model_id: str
   _revision: t.Optional[str]
-  _quantization_config: t.Optional[
-    t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
-  ]
+  _quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
   _quantise: t.Optional[LiteralQuantise]
   _model_decls: TupleAny
   __model_attrs: DictStrAny
@@ -157,9 +155,7 @@ class LLM(t.Generic[M, T], ReprMixin):
   __llm_torch_dtype__: 'torch.dtype' = None
   __llm_config__: t.Optional[LLMConfig] = None
   __llm_backend__: LiteralBackend = None
-  __llm_quantization_config__: t.Optional[
-    t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
-  ] = None
+  __llm_quantization_config__: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]] = None
   __llm_runner__: t.Optional[Runner[M, T]] = None
   __llm_model__: t.Optional[M] = None
   __llm_tokenizer__: t.Optional[T] = None
@@ -195,18 +191,18 @@ class LLM(t.Generic[M, T], ReprMixin):
     )
     dtype = torch_dtype
     _local = False
-    if validate_is_path(model_id):
-      model_id, _local = resolve_filepath(model_id), True
-    backend = first_not_none(getenv('backend', default=backend), default=self._cascade_backend())
-    dtype = first_not_none(getenv('dtype', default=dtype, var=['TORCH_DTYPE']), default='auto')
-    quantize = first_not_none(getenv('quantize', default=quantize, var=['QUANITSE']), default=None)
+    if validate_is_path(model_id): model_id, _local = resolve_filepath(model_id), True
+    backend = getenv('backend', default=backend)
+    if backend is None: backend = self._cascade_backend()
+    dtype = getenv('dtype', default=dtype, var=['TORCH_DTYPE'])
+    if dtype is None: logger.warning('Setting dtype to auto. Inferring from framework specific models'); dtype = 'auto'
+    quantize = getenv('quantize', default=quantize, var=['QUANITSE'])
     attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
     # parsing tokenizer and model kwargs, as the hierarchy is param pass > default
     model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
     if model_tag is None:
       model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
-      if model_version:
-        model_tag = f'{model_tag}:{model_version}'
+      if model_version: model_tag = f'{model_tag}:{model_version}'

     self.__attrs_init__(
       model_id=model_id,
@@ -233,102 +229,68 @@ class LLM(t.Generic[M, T], ReprMixin):
       model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
       # resolve the tag
       self._tag = model.tag
-    if not _eager and embedded:
-      raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
+    if not _eager and embedded: raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
     if embedded:
-      logger.warning(
-        'NOT RECOMMENDED in production and SHOULD ONLY used for development (Loading into current memory).'
-      )
-      self.runner.init_local(quiet=True)
+      logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.'); self.runner.init_local(quiet=True)

 class _Quantise:
   @staticmethod
-  def pt(llm: LLM, quantise=None):
-    return quantise
-
+  def pt(llm: LLM, quantise=None): return quantise
   @staticmethod
-  def vllm(llm: LLM, quantise=None):
-    return quantise
-
+  def vllm(llm: LLM, quantise=None): return quantise
   @staticmethod
   def ctranslate(llm: LLM, quantise=None):
-    if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}:
-      raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
-    if quantise == 'int8':
-      quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
+    if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}: raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
+    if quantise == 'int8': quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
     return quantise

   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]:
     model_id, *maybe_revision = model_id.rsplit(':')
     if len(maybe_revision) > 0:
-      if model_version is not None:
-        logger.warning(
-          "revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version
-        )
+      if model_version is not None: logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
       model_version = maybe_revision[0]
     if validate_is_path(model_id):
-      model_id, model_version = (
-        resolve_filepath(model_id),
-        first_not_none(model_version, default=generate_hash_from_file(model_id)),
-      )
+      model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
     return f'{backend}-{normalise_model_name(model_id)}', model_version

   @functools.cached_property
   def _has_gpus(self):
     try:
       from cuda import cuda

       err, *_ = cuda.cuInit(0)
-      if err != cuda.CUresult.CUDA_SUCCESS:
-        raise RuntimeError('Failed to initialise CUDA runtime binding.')
+      if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.')
       err, num_gpus = cuda.cuDeviceGetCount()
-      if err != cuda.CUresult.CUDA_SUCCESS:
-        raise RuntimeError('Failed to get CUDA device count.')
+      if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to get CUDA device count.')
       return True
     except (ImportError, RuntimeError):
       return False

   @property
   def _torch_dtype(self):
     import torch, transformers

     _map = _torch_dtype_mapping()
     if not isinstance(self.__llm_torch_dtype__, torch.dtype):
       try:
-        hf_config = transformers.AutoConfig.from_pretrained(
-          self.bentomodel.path, trust_remote_code=self.trust_remote_code
-        )
+        hf_config = transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code)
       except OpenLLMException:
         hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
       config_dtype = getattr(hf_config, 'torch_dtype', None)
-      if config_dtype is None:
-        config_dtype = torch.float32
+      if config_dtype is None: config_dtype = torch.float32
       if self.__llm_dtype__ == 'auto':
         if config_dtype == torch.float32:
           torch_dtype = torch.float16
         else:
           torch_dtype = config_dtype
       else:
-        if self.__llm_dtype__ not in _map:
-          raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
+        if self.__llm_dtype__ not in _map: raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'")
         torch_dtype = _map[self.__llm_dtype__]
       self.__llm_torch_dtype__ = torch_dtype
     return self.__llm_torch_dtype__

   @property
-  def _model_attrs(self):
-    return {**self.import_kwargs[0], **self.__model_attrs}
-
+  def _model_attrs(self): return {**self.import_kwargs[0], **self.__model_attrs}
   @_model_attrs.setter
-  def _model_attrs(self, value):
-    self.__model_attrs = value
-
+  def _model_attrs(self, value): self.__model_attrs = value
   @property
-  def _tokenizer_attrs(self):
-    return {**self.import_kwargs[1], **self.__tokenizer_attrs}
-
+  def _tokenizer_attrs(self): return {**self.import_kwargs[1], **self.__tokenizer_attrs}
   def _cascade_backend(self) -> LiteralBackend:
     logger.warning('It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.')
     if self._has_gpus:
       if is_vllm_available():
         return 'vllm'
@@ -338,93 +300,53 @@ class LLM(t.Generic[M, T], ReprMixin):
         return 'ctranslate'
     else:
       return 'pt'

   def __setattr__(self, attr, value):
-    if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}:
-      raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
+    if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}: raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
     super().__setattr__(attr, value)

   def __del__(self):
     try:
       del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__
     except AttributeError:
       pass

   @property
-  def __repr_keys__(self):
-    return {'model_id', 'revision', 'backend', 'type'}
-
-  def __repr_args__(self):
-    yield 'model_id', self._model_id if not self._local else self.tag.name
-    yield 'revision', self._revision if self._revision else self.tag.version
-    yield 'backend', self.__llm_backend__
-    yield 'type', self.llm_type
-
+  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
+  def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type))
   @property
-  def import_kwargs(self):
-    return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {
-      'padding_side': 'left',
-      'truncation_side': 'left',
-    }
-
+  def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
   def trust_remote_code(self):
     env = os.getenv('TRUST_REMOTE_CODE')
-    if env is not None:
-      check_bool_env('TRUST_REMOTE_CODE', env)
+    if env is not None: return check_bool_env('TRUST_REMOTE_CODE', env)
     return self.__llm_trust_remote_code__

   @property
-  def model_id(self):
-    return self._model_id
-
+  def model_id(self): return self._model_id
   @property
-  def revision(self):
-    return self._revision
-
+  def revision(self): return self._revision
   @property
-  def tag(self):
-    return self._tag
-
+  def tag(self): return self._tag
   @property
-  def bentomodel(self):
-    return openllm.serialisation.get(self)
-
+  def bentomodel(self): return openllm.serialisation.get(self)
   @property
   def quantization_config(self):
     if self.__llm_quantization_config__ is None:
       from ._quantisation import infer_quantisation_config

       if self._quantization_config is not None:
         self.__llm_quantization_config__ = self._quantization_config
       elif self._quantise is not None:
-        self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
-          self, self._quantise, **self._model_attrs
-        )
+        self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(self, self._quantise, **self._model_attrs)
       else:
         raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
     return self.__llm_quantization_config__

   @property
-  def has_adapters(self):
-    return self._adapter_map is not None
-
+  def has_adapters(self): return self._adapter_map is not None
   @property
-  def local(self):
-    return self._local
-
+  def local(self): return self._local
   @property
-  def quantise(self):
-    return self._quantise
-
+  def quantise(self): return self._quantise
   @property
-  def llm_type(self):
-    return normalise_model_name(self._model_id)
-
+  def llm_type(self): return normalise_model_name(self._model_id)
   @property
-  def llm_parameters(self):
-    return (self._model_decls, self._model_attrs), self._tokenizer_attrs
-
+  def llm_parameters(self): return (self._model_decls, self._model_attrs), self._tokenizer_attrs
   @property
   def identifying_params(self):
     return {
@@ -432,24 +354,17 @@ class LLM(t.Generic[M, T], ReprMixin):
       'model_ids': orjson.dumps(self.config['model_ids']).decode(),
       'model_id': self.model_id,
     }

   @property
   def tokenizer(self):
-    if self.__llm_tokenizer__ is None:
-      self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
+    if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
     return self.__llm_tokenizer__

   @property
   def runner(self):
     from ._runners import runner

-    if self.__llm_runner__ is None:
-      self.__llm_runner__ = runner(self)
+    if self.__llm_runner__ is None: self.__llm_runner__ = runner(self)
     return self.__llm_runner__

   def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs):
-    if self.__llm_backend__ != 'pt':
-      raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
+    if self.__llm_backend__ != 'pt': raise RuntimeError('Fine tuning is only supported for PyTorch backend.')
     from peft.mapping import get_peft_model
     from peft.utils.other import prepare_model_for_kbit_training
@@ -457,37 +372,25 @@ class LLM(t.Generic[M, T], ReprMixin):
       prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking),
       self.config['fine_tune_strategies']
       .get(adapter_type, self.config.make_fine_tune_config(adapter_type))
-      .train()
-      .with_config(**attrs)
-      .build(),
+      .train().with_config(**attrs).build(),
     )
-    if DEBUG:
-      model.print_trainable_parameters()
+    if DEBUG: model.print_trainable_parameters()
     return model, self.tokenizer

-  def prepare_for_training(self, *args, **attrs):
-    logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.')
-    return self.prepare(*args, **attrs)
-
+  def prepare_for_training(self, *args, **attrs): logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.'); return self.prepare(*args, **attrs)
   @property
   def adapter_map(self):
-    if not is_peft_available():
-      raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
-    if not self.has_adapters:
-      raise AttributeError('Adapter map is not available.')
+    if not is_peft_available(): raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'")
+    if not self.has_adapters: raise AttributeError('Adapter map is not available.')
     assert self._adapter_map is not None
     if self.__llm_adapter_map__ is None:
       _map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
       for adapter_type, adapter_tuple in self._adapter_map.items():
         base = first_not_none(
-          self.config['fine_tune_strategies'].get(adapter_type),
-          default=self.config.make_fine_tune_config(adapter_type),
+          self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type),
         )
-        for adapter in adapter_tuple:
-          _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
+        for adapter in adapter_tuple: _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
       self.__llm_adapter_map__ = _map
     return self.__llm_adapter_map__

   @property
   def model(self):
     if self.__llm_model__ is None:
@@ -495,12 +398,7 @@ class LLM(t.Generic[M, T], ReprMixin):
       # If OOM, then it is probably you don't have enough VRAM to run this model.
       if self.__llm_backend__ == 'pt':
         import torch

-        loaded_in_kbit = (
-          getattr(model, 'is_loaded_in_8bit', False)
-          or getattr(model, 'is_loaded_in_4bit', False)
-          or getattr(model, 'is_quantized', False)
-        )
+        loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
         if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
           try:
             model = model.to('cuda')
@@ -513,17 +411,13 @@ class LLM(t.Generic[M, T], ReprMixin):
           model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
     self.__llm_model__ = model
     return self.__llm_model__

   @property
   def config(self):
     import transformers

     if self.__llm_config__ is None:
       if self.__llm_backend__ == 'ctranslate':
         try:
-          config = transformers.AutoConfig.from_pretrained(
-            self.bentomodel.path_of('/hf'), trust_remote_code=self.trust_remote_code
-          )
+          config = transformers.AutoConfig.from_pretrained(self.bentomodel.path_of('/hf'), trust_remote_code=self.trust_remote_code)
         except OpenLLMException:
           config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
         for architecture in config.architectures:
@@ -533,47 +427,26 @@ class LLM(t.Generic[M, T], ReprMixin):
             ).model_construct_env(**self._model_attrs)
             break
         else:
-          raise OpenLLMException(
-            f"Failed to infer the configuration class from the given model. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}"
-          )
+          raise OpenLLMException(f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}")
       else:
         config = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
       self.__llm_config__ = config
     return self.__llm_config__


 @functools.lru_cache(maxsize=1)
 def _torch_dtype_mapping() -> dict[str, torch.dtype]:
-  import torch
-
-  return {
-    'half': torch.float16,
-    'float16': torch.float16,
-    'float': torch.float32,
-    'float32': torch.float32,
+  import torch; return {
+    'half': torch.float16, 'float16': torch.float16,
+    'float': torch.float32, 'float32': torch.float32,
     'bfloat16': torch.bfloat16,
   }


-def normalise_model_name(name: str) -> str:
-  return (
-    os.path.basename(resolve_filepath(name))
-    if validate_is_path(name)
-    else inflection.dasherize(name.replace('/', '--'))
-  )
-
-
+def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--'))
 def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
-  if not is_peft_available():
-    raise RuntimeError(
-      "LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
-    )
+  if not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
   from huggingface_hub import hf_hub_download

   resolved: AdapterMap = {}
   for path_or_adapter_id, name in adapter_map.items():
-    if name is None:
-      raise ValueError('Adapter name must be specified.')
+    if name is None: raise ValueError('Adapter name must be specified.')
     if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
       config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
     else:
@@ -584,7 +457,6 @@ def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
     with open(config_file, 'r') as file:
       resolved_config = orjson.loads(file.read())
     _peft_type = resolved_config['peft_type'].lower()
-    if _peft_type not in resolved:
-      resolved[_peft_type] = ()
+    if _peft_type not in resolved: resolved[_peft_type] = ()
     resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
   return resolved
@@ -6,6 +6,7 @@ import _service_vars as svars

 import bentoml
 import openllm
+from openllm_core._schemas import MessageParam
 from bentoml.io import JSON, Text

 logger = logging.getLogger(__name__)
@@ -61,11 +62,6 @@ class MessagesConverterInput(t.TypedDict):
   messages: t.List[t.Dict[str, t.Any]]


-class MessageParam(t.TypedDict):
-  role: t.Literal['system', 'user', 'assistant']
-  content: str
-
-
 @svc.api(
   route='/v1/helpers/messages',
   input=JSON.from_sample(
@@ -64,17 +64,15 @@ requestBody:
       one-shot:
         summary: One-shot input example
        value:
-          messages:
-            - role: system
-              content: You are a helpful assistant.
-            - role: user
-              content: Hello, I'm looking for a chatbot that can help me with my work.
+          messages: __chat_messages__
           model: __model_id__
           max_tokens: 256
           temperature: 0.7
           top_p: 0.43
           n: 1
           stream: false
+          chat_template: __chat_template__
+          add_generation_prompt: __add_generation_prompt__
       streaming:
         summary: Streaming input example
         value:
@@ -92,6 +90,8 @@ requestBody:
           stop:
             - "\\n"
             - "<|endoftext|>"
+          chat_template: __chat_template__
+          add_generation_prompt: __add_generation_prompt__
       schema:
         $ref: '#/components/schemas/ChatCompletionRequest'
 responses:
@@ -56,24 +56,16 @@ schemas = get_generator(
 logger = logging.getLogger(__name__)


-def jsonify_attr(obj):
-  return orjson.dumps(converter.unstructure(obj)).decode()
-
+def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode()


 def error_response(status_code, message):
   return JSONResponse(
-    {
-      'error': converter.unstructure(
-        ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
-      )
-    },
+    {'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))},
     status_code=status_code.value,
   )


 async def check_model(request, model):
-  if request.model == model:
-    return None
+  if request.model == model: return None
   return error_response(
     HTTPStatus.NOT_FOUND,
     f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
@@ -93,7 +85,6 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
     else:
       logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
     last_token_len = len(token)
-
     logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
   return logprobs
@@ -106,7 +97,9 @@ def mount_to_svc(svc, llm):
     debug=True,
     routes=[
       Route(
-        '/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
+        '/models',
+        functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm),
+        methods=['GET']
       ),
       Route(
         '/completions',
@@ -115,7 +108,11 @@ def mount_to_svc(svc, llm):
       ),
       Route(
         '/chat/completions',
-        functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
+        functools.partial(apply_schema(chat_completions,
+                                       __model_id__=llm.llm_type,
+                                       __chat_template__=orjson.dumps(llm.config.chat_template).decode(),
+                                       __chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
+                                       __add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False)), llm=llm),
         methods=['POST'],
       ),
       Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
@@ -127,11 +124,7 @@ def mount_to_svc(svc, llm):

 # GET /v1/models
 @add_schema_definitions
-def list_models(_, llm):
-  return JSONResponse(
-    converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
-  )
-
+def list_models(_, llm): return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)

 # POST /v1/chat/completions
 @add_schema_definitions
@@ -141,27 +134,22 @@ async def chat_completions(req, llm):
   try:
     request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
   except orjson.JSONDecodeError as err:
-    logger.debug('Sent body: %s', json_str)
-    logger.error('Invalid JSON input received: %s', err)
+    logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
     return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
   logger.debug('Received chat completion request: %s', request)
   err_check = await check_model(request, llm.llm_type)
-  if err_check is not None:
-    return err_check
+  if err_check is not None: return err_check

   model_name, request_id = request.model, gen_random_uuid('chatcmpl')
   created_time = int(time.monotonic())
-  prompt = llm.tokenizer.apply_chat_template(
-    request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
-  )
+  prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
   logger.debug('Prompt: %r', prompt)
   config = llm.config.compatible_options(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
   except Exception as err:
-    traceback.print_exc()
-    logger.error('Error generating completion: %s', err)
+    traceback.print_exc(); logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

   def create_stream_response_json(index, text, finish_reason=None, usage=None):
@@ -169,9 +157,7 @@ async def chat_completions(req, llm):
       id=request_id,
       created=created_time,
       model=model_name,
-      choices=[
-        ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
-      ],
+      choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)],
     )
     if usage is not None: response.usage = usage
     return jsonify_attr(response)
@@ -194,20 +180,17 @@ async def chat_completions(req, llm):

   try:
     # Streaming case
-    if request.stream:
-      return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
+    if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
     final_result = None
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
-      if await req.is_disconnected():
-        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
+      if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
         texts[output.index].append(output.text)
         token_ids[output.index].extend(output.token_ids)
       final_result = res
-    if final_result is None:
-      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
+    if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
     final_result = final_result.with_options(
       outputs=[
         output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
@@ -225,25 +208,18 @@ async def chat_completions(req, llm):
     num_prompt_tokens = len(final_result.prompt_token_ids)
     num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
     usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
-    response = ChatCompletionResponse(
-      id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
-    )
+    response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)

-    if request.stream: # type: ignore[unreachable]
+    if request.stream:
       # When user requests streaming but we don't stream, we still need to
       # return a streaming response with a single event.
-      async def fake_stream_generator() -> t.AsyncGenerator[str, None]: # type: ignore[unreachable]
-        yield f'data: {jsonify_attr(response)}\n\n'
-        yield 'data: [DONE]\n\n'
-
-      return StreamingResponse(
-        fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
-      )
+      async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
+        yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
+      return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)

     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
-    traceback.print_exc()
-    logger.error('Error generating completion: %s', err)
+    traceback.print_exc(); logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
@@ -255,23 +231,17 @@ async def completions(req, llm):
   try:
     request = converter.structure(orjson.loads(json_str), CompletionRequest)
   except orjson.JSONDecodeError as err:
-    logger.debug('Sent body: %s', json_str)
-    logger.error('Invalid JSON input received: %s', err)
+    logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
     return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
   logger.debug('Received legacy completion request: %s', request)
   err_check = await check_model(request, llm.llm_type)
-  if err_check is not None:
-    return err_check
+  if err_check is not None: return err_check

-  if request.echo:
-    return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
-  if request.suffix is not None:
-    return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
-  if request.logit_bias is not None and len(request.logit_bias) > 0:
-    return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
+  if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
+  if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
+  if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")

-  if not request.prompt:
-    return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
+  if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
   prompt = request.prompt
   # TODO: Support multiple prompts
@@ -282,8 +252,7 @@ async def completions(req, llm):
   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
   except Exception as err:
-    traceback.print_exc()
-    logger.error('Error generating completion: %s', err)
+    traceback.print_exc(); logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

   # best_of != n then we don't stream
@@ -295,9 +264,7 @@ async def completions(req, llm):
       id=request_id,
       created=created_time,
       model=model_name,
-      choices=[
-        CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
-      ],
+      choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
     )
     if usage: response.usage = usage
     return jsonify_attr(response)
@@ -308,12 +275,9 @@ async def completions(req, llm):
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
-        logprobs = None
-        if request.logprobs is not None:
-          logprobs = create_logprobs(
-            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
-          )
+        if request.logprobs is not None:
+          logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
+        else:
+          logprobs = None
         previous_num_tokens[i] += len(output.token_ids)
         previous_texts[i] += output.text
         yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
@@ -326,20 +290,17 @@ async def completions(req, llm):

   try:
     # Streaming case
-    if stream:
-      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
+    if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
     final_result = None
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
-      if await req.is_disconnected():
-        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
+      if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
         texts[output.index].append(output.text)
         token_ids[output.index].extend(output.token_ids)
       final_result = res
-    if final_result is None:
-      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
+    if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
     final_result = final_result.with_options(
       outputs=[
         output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
@@ -349,15 +310,10 @@ async def completions(req, llm):
     choices = []
     for output in final_result.outputs:
-      logprobs = None
       if request.logprobs is not None:
-        logprobs = create_logprobs(
-          token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
-        )
+        logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
       else:
         logprobs = None
-      choice_data = CompletionResponseChoice(
-        index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
-      )
+      choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
       choices.append(choice_data)

     num_prompt_tokens = len(final_result.prompt_token_ids)
@@ -369,13 +325,8 @@ async def completions(req, llm):
       # When user requests streaming but we don't stream, we still need to
       # return a streaming response with a single event.
       async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
-        yield f'data: {jsonify_attr(response)}\n\n'
-        yield 'data: [DONE]\n\n'
-
-      return StreamingResponse(
-        fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
-      )
-
+        yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
+      return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
     traceback.print_exc()
@@ -62,6 +62,9 @@ class ChatCompletionRequest:
   # supported by vLLM and us
   top_k: t.Optional[int] = attr.field(default=None)
   best_of: t.Optional[int] = attr.field(default=1)
+  # Additional features to support chat_template
+  chat_template: str = attr.field(default=None)
+  add_generation_prompt: bool = attr.field(default=True)


 @attr.define
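For reference, the server-side prompt construction in `chat_completions` now hands `request.chat_template` and `request.add_generation_prompt` directly to the tokenizer's `apply_chat_template` instead of always using the config default. A minimal, self-contained sketch of that call; the checkpoint name is an assumption, and any chat-capable `transformers` tokenizer behaves the same way:

```python
# Minimal sketch of the prompt construction now done in chat_completions.
# Assumes a chat-capable tokenizer; the checkpoint name is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
messages = [
  {'role': 'system', 'content': 'You are a helpful assistant.'},
  {'role': 'user', 'content': 'Hello there!'},
]
# chat_template=None falls back to the template bundled with the tokenizer,
# mirroring the `request.chat_template != 'None'` guard in the diff.
prompt = tokenizer.apply_chat_template(
  messages,
  tokenize=False,              # return the rendered string, not token ids
  chat_template=None,          # or a custom Jinja template string from the request
  add_generation_prompt=True,  # append the assistant header so generation continues from it
)
print(prompt)
```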