Mirror of https://github.com/bentoml/OpenLLM.git
chore(llm): expose quantise and lazy load heavy imports (#617)
* chore(llm): expose quantise and lazy load heavy imports
* chore: move transformers to TYPE_CHECKING block

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
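The diff below applies two related patterns: heavy modules (`torch`, `transformers`) are no longer imported at module import time -- `transformers` moves under the `t.TYPE_CHECKING` block and `torch` is imported inside the methods that need it -- and the private `_quantise` attribute is surfaced through a read-only `quantise` property so call sites stop reaching into `llm._quantise`. A minimal sketch of the pattern, for orientation only; `MyLLM` and the locally defined `LiteralQuantise` below are simplified stand-ins, not the actual OpenLLM classes:

from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
  # Only seen by type checkers; never imported at runtime.
  import transformers

# Stand-in for the LiteralQuantise alias used in the diff (assumption, simplified here).
LiteralQuantise = t.Literal['int8', 'int4', 'gptq', 'awq', 'squeezellm']


class MyLLM:
  def __init__(self, quantise: LiteralQuantise | None = None) -> None:
    self._quantise = quantise

  @property
  def quantise(self) -> LiteralQuantise | None:
    # Read-only accessor so callers use `llm.quantise` instead of `llm._quantise`.
    return self._quantise

  @property
  def import_kwargs(self) -> dict[str, t.Any]:
    # Lazy import: torch is only loaded when these kwargs are actually needed.
    import torch

    return {
      'device_map': 'auto' if torch.cuda.is_available() else None,
      'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
    }

Downstream call sites in the diff (the vLLM runnable, the serialisation helpers, and the build CLI) are updated from `llm._quantise` to `llm.quantise` accordingly.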
@@ -9,8 +9,6 @@ import typing as t
import attr
import inflection
import orjson
import torch
import transformers

from huggingface_hub import hf_hub_download

@@ -58,6 +56,7 @@ from .serialisation.constants import PEFT_CONFIG_NAME

if t.TYPE_CHECKING:
  import peft
  import transformers

  from bentoml._internal.runner.runnable import RunnableMethod
  from bentoml._internal.runner.runner import RunnerMethod
@@ -124,8 +123,8 @@ class LLM(t.Generic[M, T], ReprMixin):
  _quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
  _quantise: LiteralQuantise | None
  _model_decls: TupleAny
  _model_attrs: DictStrAny
  _tokenizer_attrs: DictStrAny
  __model_attrs: DictStrAny
  __tokenizer_attrs: DictStrAny
  _tag: bentoml.Tag
  _adapter_map: AdapterMap | None
  _serialisation: LiteralSerialisation
@@ -133,7 +132,6 @@ class LLM(t.Generic[M, T], ReprMixin):
  _prompt_template: PromptTemplate | None
  _system_message: str | None

  _bentomodel: bentoml.Model = attr.field(init=False)
  __llm_config__: LLMConfig | None = None
  __llm_backend__: LiteralBackend = None # type: ignore
  __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -168,16 +166,11 @@ class LLM(t.Generic[M, T], ReprMixin):
    _local = False
    if validate_is_path(model_id):
      model_id, _local = resolve_filepath(model_id), True
    backend = t.cast(
      LiteralBackend,
      first_not_none(
        backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
      ),
    )

    quantize = first_not_none(
      quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None
    backend = first_not_none(
      backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
    )
    quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None)
    # elif quantization_config is None and quantize is not None:
    #   quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
    attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
@@ -199,17 +192,17 @@ class LLM(t.Generic[M, T], ReprMixin):
    self.__attrs_init__(
      model_id=model_id,
      revision=model_version,
      tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
      tag=bentoml.Tag.from_taglike(model_tag),
      quantization_config=quantization_config,
      quantise=quantize,
      model_decls=args,
      model_attrs=dict(**self.import_kwargs[0], **model_attrs),
      tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
      adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
      serialisation=serialisation,
      local=_local,
      prompt_template=prompt_template,
      system_message=system_message,
      LLM__model_attrs=model_attrs,
      LLM__tokenizer_attrs=tokenizer_attrs,
      llm_backend__=backend,
      llm_config__=llm_config,
      llm_trust_remote_code__=trust_remote_code,
@@ -221,7 +214,6 @@ class LLM(t.Generic[M, T], ReprMixin):
      model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
      # resolve the tag
      self._tag = model.tag
      self._bentomodel = model

  @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
  def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
@@ -241,72 +233,141 @@ class LLM(t.Generic[M, T], ReprMixin):
    )
    return f'{backend}-{normalise_model_name(model_id)}', model_version

  # yapf: disable
  def __setattr__(self,attr,value):
    if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
    super().__setattr__(attr,value)
  def __setattr__(self, attr, value):
    if attr in _reserved_namespace:
      raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
    super().__setattr__(attr, value)

  @property
  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
  def _model_attrs(self) -> dict[str, t.Any]:
    return {**self.import_kwargs[0], **self.__model_attrs}

  @property
  def _tokenizer_attrs(self) -> dict[str, t.Any]:
    return {**self.import_kwargs[1], **self.__tokenizer_attrs}

  @property
  def __repr_keys__(self):
    return {'model_id', 'revision', 'backend', 'type'}

  def __repr_args__(self):
    yield 'model_id',self._model_id if not self._local else self.tag.name
    yield 'revision',self._revision if self._revision else self.tag.version
    yield 'backend',self.__llm_backend__
    yield 'type',self.llm_type
    yield 'model_id', self._model_id if not self._local else self.tag.name
    yield 'revision', self._revision if self._revision else self.tag.version
    yield 'backend', self.__llm_backend__
    yield 'type', self.llm_type
  @property
  def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
  def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
    import torch

    return {
      'device_map': 'auto' if torch.cuda.is_available() else None,
      'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
    }, {'padding_side': 'left', 'truncation_side': 'left'}

  @property
  def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
  def trust_remote_code(self) -> bool:
    return first_not_none(check_bool_env('TRUST_REMOTE_CODE', False), default=self.__llm_trust_remote_code__)

  @property
  def runner_name(self)->str:return f"llm-{self.config['start_name']}-runner"
  def runner_name(self) -> str:
    return f"llm-{self.config['start_name']}-runner"

  @property
  def model_id(self)->str:return self._model_id
  def model_id(self) -> str:
    return self._model_id

  @property
  def revision(self)->str:return t.cast(str, self._revision)
  def revision(self) -> str:
    return t.cast(str, self._revision)

  @property
  def tag(self)->bentoml.Tag:return self._tag
  def tag(self) -> bentoml.Tag:
    return self._tag

  @property
  def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
  def bentomodel(self) -> bentoml.Model:
    return openllm.serialisation.get(self)
  @property
  def quantization_config(self)->transformers.BitsAndBytesConfig|transformers.GPTQConfig|transformers.AwqConfig:
  def quantization_config(self) -> transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig:
    if self.__llm_quantization_config__ is None:
      if self._quantization_config is not None:self.__llm_quantization_config__=self._quantization_config
      elif self._quantise is not None:self.__llm_quantization_config__,self._model_attrs=infer_quantisation_config(self, self._quantise, **self._model_attrs)
      else:raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
      if self._quantization_config is not None:
        self.__llm_quantization_config__ = self._quantization_config
      elif self._quantise is not None:
        self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
          self, self._quantise, **self._model_attrs
        )
      else:
        raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
    return self.__llm_quantization_config__

  @property
  def has_adapters(self)->bool:return self._adapter_map is not None
  def has_adapters(self) -> bool:
    return self._adapter_map is not None

  @property
  def local(self)->bool:return self._local
  def local(self) -> bool:
    return self._local

  @property
  def quantise(self) -> LiteralQuantise | None:
    return self._quantise
  # NOTE: The section below defines a loose contract with langchain's LLM interface.
  @property
  def llm_type(self)->str:return normalise_model_name(self._model_id)
  def llm_type(self) -> str:
    return normalise_model_name(self._model_id)

  @property
  def identifying_params(self)->DictStrAny:return {'configuration': self.config.model_dump_json().decode(),'model_ids': orjson.dumps(self.config['model_ids']).decode(),'model_id': self.model_id}
  def identifying_params(self) -> DictStrAny:
    return {
      'configuration': self.config.model_dump_json().decode(),
      'model_ids': orjson.dumps(self.config['model_ids']).decode(),
      'model_id': self.model_id,
    }

  @property
  def llm_parameters(self)->tuple[tuple[tuple[t.Any,...],DictStrAny],DictStrAny]:return (self._model_decls,self._model_attrs),self._tokenizer_attrs
  def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], DictStrAny], DictStrAny]:
    return (self._model_decls, self._model_attrs), self._tokenizer_attrs

  # NOTE: This section is the actual model, tokenizer, and config reference here.
  @property
  def config(self)->LLMConfig:
    if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
  def config(self) -> LLMConfig:
    if self.__llm_config__ is None:
      self.__llm_config__ = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
    return self.__llm_config__

  @property
  def tokenizer(self)->T:
    if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self,**self.llm_parameters[-1])
  def tokenizer(self) -> T:
    if self.__llm_tokenizer__ is None:
      self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
    return self.__llm_tokenizer__
  @property
  def runner(self)->LLMRunner[M, T]:
    if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self)
  def runner(self) -> LLMRunner[M, T]:
    if self.__llm_runner__ is None:
      self.__llm_runner__ = _RunnerFactory(self)
    return self.__llm_runner__

  @property
  def model(self)->M:
  def model(self) -> M:
    if self.__llm_model__ is None:
      model=openllm.serialisation.load_model(self,*self._model_decls,**self._model_attrs)
      model = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
      # If OOM, then it is probably you don't have enough VRAM to run this model.
      if self.__llm_backend__ == 'pt':
        loaded_in_kbit = getattr(model,'is_loaded_in_8bit',False) or getattr(model,'is_loaded_in_4bit',False) or getattr(model,'is_quantized',False)
        import torch

        loaded_in_kbit = (
          getattr(model, 'is_loaded_in_8bit', False)
          or getattr(model, 'is_loaded_in_4bit', False)
          or getattr(model, 'is_quantized', False)
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
          try: model = model.to('cuda')
          except Exception as err: raise OpenLLMException(f'Failed to load model into GPU: {err}\n. See https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.') from err
          try:
            model = model.to('cuda')
          except Exception as err:
            raise OpenLLMException(f'Failed to load model into GPU: {err}.\n') from err
      if self.has_adapters:
        logger.debug('Applying the following adapters: %s', self.adapter_map)
        for adapter_dict in self.adapter_map.values():
@@ -314,23 +375,29 @@ class LLM(t.Generic[M, T], ReprMixin):
          model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
      self.__llm_model__ = model
    return self.__llm_model__

  @property
  def adapter_map(self) -> ResolvedAdapterMap:
    try:
      import peft as _ # noqa: F401
    except ImportError as err:
      raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err
    if not self.has_adapters: raise AttributeError('Adapter map is not available.')
      raise MissingDependencyError(
        "Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'"
      ) from err
    if not self.has_adapters:
      raise AttributeError('Adapter map is not available.')
    assert self._adapter_map is not None
    if self.__llm_adapter_map__ is None:
      _map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
      for adapter_type, adapter_tuple in self._adapter_map.items():
        base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
        base = first_not_none(
          self.config['fine_tune_strategies'].get(adapter_type),
          default=self.config.make_fine_tune_config(adapter_type),
        )
        for adapter in adapter_tuple:
          _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
      self.__llm_adapter_map__ = _map
    return self.__llm_adapter_map__
  # yapf: enable

  def prepare_for_training(
    self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
@@ -475,15 +542,24 @@ def _RunnerFactory(
  else:
    system_message = None

  # yapf: disable
  def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
  def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]:
    return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}

  def _wrapped_repr_args(_: LLMRunner[M, T]) -> ReprArgs:
    yield 'runner_methods', {method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None} for method in _.runner_methods}
    yield (
      'runner_methods',
      {
        method.name: {
          'batchable': method.config.batchable,
          'batch_dim': method.config.batch_dim if method.config.batchable else None,
        }
        for method in _.runner_methods
      },
    )
    yield 'config', self.config.model_dump(flatten=True)
    yield 'llm_type', self.llm_type
    yield 'backend', backend
    yield 'llm_tag', self.tag
  # yapf: enable

  return types.new_class(
    self.__class__.__name__ + 'Runner',
@@ -50,8 +50,8 @@ class vLLMRunnable(bentoml.Runnable):
    if dev >= 2:
      num_gpus = min(dev // 2 * 2, dev)
    quantization = None
    if llm._quantise and llm._quantise in {'awq', 'squeezellm'}:
      quantization = llm._quantise
    if llm.quantise and llm.quantise in {'awq', 'squeezellm'}:
      quantization = llm.quantise
    try:
      self.model = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(
@@ -111,7 +111,6 @@ class PyTorchRunnable(bentoml.Runnable):
    self.model = llm.model
    self.tokenizer = llm.tokenizer
    self.config = llm.config
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(
@@ -155,17 +154,17 @@ class PyTorchRunnable(bentoml.Runnable):
    finish_reason: t.Optional[FinishReason] = None
    for i in range(config['max_new_tokens']):
      if i == 0: # prefill
        out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
        out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True)
      else: # decoding
        out = self.model(
          torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values
          torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values
        )
      logits = out.logits
      past_key_values = out.past_key_values

      if logits_processor:
        if config['repetition_penalty'] > 1.0:
          tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
          tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
        else:
          tmp_output_ids = None
        last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
@@ -173,7 +172,7 @@ class PyTorchRunnable(bentoml.Runnable):
        last_token_logits = logits[0, -1, :]

      # Switch to CPU by avoiding some bugs in mps backend.
      if self.device.type == 'mps':
      if self.model.device.type == 'mps':
        last_token_logits = last_token_logits.float().to('cpu')

      if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
@@ -148,7 +148,7 @@ def construct_docker_options(
  if llm._prompt_template:
    env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
  if quantize:
    env_dict['OPENLLM_QUANTISE'] = str(quantize)
    env_dict['OPENLLM_QUANTIZE'] = str(quantize)
  return DockerOptions(
    base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}',
    env=env_dict,

@@ -621,8 +621,8 @@ def process_environ(
      'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
    }
  )
  if llm._quantise:
    environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
  if llm.quantise:
    environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
  if system_message:
    environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
  if prompt_template:
@@ -650,8 +650,8 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]

def build_bento_instruction(llm, model_id, serialisation, adapter_map):
  cmd_name = f'openllm build {model_id}'
  if llm._quantise:
    cmd_name += f' --quantize {llm._quantise}'
  if llm.quantise:
    cmd_name += f' --quantize {llm.quantise}'
  cmd_name += f' --serialization {serialisation}'
  if adapter_map is not None:
    cmd_name += ' ' + ' '.join(
@@ -994,8 +994,8 @@ def build_command(
      'OPENLLM_MODEL_ID': llm.model_id,
    }
  )
  if llm._quantise:
    os.environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
  if llm.quantise:
    os.environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
  if system_message:
    os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
  if prompt_template:
@@ -24,7 +24,7 @@ if t.TYPE_CHECKING:
  P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODEL_SCHEMA = """\
LIST_MODELS_SCHEMA = """\
---
consumes:
- application/json
@@ -55,14 +55,14 @@ responses:
          schema:
            $ref: '#/components/schemas/ModelList'
"""
CHAT_COMPLETION_SCHEMA = """\
CHAT_COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
description: >-
  Given a list of messages comprising a conversation, the model will return a
  response.
operationId: openai__create_chat_completions
operationId: openai__chat_completions
produces:
- application/json
tags:
@@ -193,7 +193,7 @@ responses:
      }
  description: Bad Request
"""
COMPLETION_SCHEMA = """\
COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
@@ -201,7 +201,7 @@ description: >-
  Given a prompt, the model will return one or more predicted completions, and
  can also return the probabilities of alternative tokens at each position. We
  recommend most users use our Chat completions API.
operationId: openai__create_completions
operationId: openai__completions
produces:
- application/json
tags:
@@ -423,15 +423,17 @@ responses:
  description: Not Found
"""
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}

def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
  def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
    if func.__doc__ is None:
      func.__doc__ = ''
    func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()

def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
  append_str = _SCHEMAS.get(func.__name__.lower(), '')
  if not append_str:
    return func

  return docstring_decorator
  if func.__doc__ is None:
    func.__doc__ = ''
  func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
  return func


class OpenLLMSchemaGenerator(SchemaGenerator):
@@ -558,7 +560,7 @@ def append_schemas(
  # HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
  from bentoml._internal.service.openapi.specification import OpenAPISpecification

  svc_schema: t.Any = svc.openapi_spec
  svc_schema = svc.openapi_spec
  if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
    svc_schema = svc_schema.asdict()
  if 'tags' in generated_schema:
@@ -572,14 +574,15 @@ def append_schemas(
    svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
  svc_schema['paths'].update(generated_schema['paths'])

  from bentoml._internal.service import (
    openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
  )
  # HACK: mk this attribute until we have a better way to add starlette schemas.
  from bentoml._internal.service import openapi

  # yapf: disable
  def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
  def mk_asdict(self:OpenAPISpecification)->dict[str,t.Any]:return svc_schema
  openapi.generate_spec=mk_generate_spec
  def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
    return MKSchema(svc_schema)

  def mk_asdict(self):
    return svc_schema

  openapi.generate_spec = mk_generate_spec
  OpenAPISpecification.asdict = mk_asdict
  # yapf: disable
  return svc
@@ -14,8 +14,6 @@ from starlette.routing import Route

from openllm_core.utils import converter

from ._openapi import HF_ADAPTERS_SCHEMA
from ._openapi import HF_AGENT_SCHEMA
from ._openapi import add_schema_definitions
from ._openapi import append_schemas
from ._openapi import get_generator
@@ -54,7 +52,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
    debug=True,
    routes=[
      Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
      Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
      Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']),
      Route('/schema', endpoint=openapi_schema, include_in_schema=False),
    ],
  )
@@ -71,7 +69,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
  )


@add_schema_definitions(HF_AGENT_SCHEMA)
@add_schema_definitions
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
  json_str = await req.body()
  try:
@@ -92,8 +90,8 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')


@add_schema_definitions(HF_ADAPTERS_SCHEMA)
def adapters_map(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
def hf_adapters(req: Request, llm: openllm.LLM[M, T]) -> Response:
  if not llm.has_adapters:
    return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
  return JSONResponse(
@@ -18,9 +18,6 @@ from openllm_core._schemas import SampleLogprobs
from openllm_core.utils import converter
from openllm_core.utils import gen_random_uuid

from ._openapi import CHAT_COMPLETION_SCHEMA
from ._openapi import COMPLETION_SCHEMA
from ._openapi import LIST_MODEL_SCHEMA
from ._openapi import add_schema_definitions
from ._openapi import append_schemas
from ._openapi import get_generator
@@ -127,8 +124,8 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
    debug=True,
    routes=[
      Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
      Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
      Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST']),
      Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
      Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
    ],
  )
  mount_path = '/v1'
@@ -138,7 +135,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic


# GET /v1/models
@add_schema_definitions(LIST_MODEL_SCHEMA)
@add_schema_definitions
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
  return JSONResponse(
    converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
@@ -146,8 +143,8 @@ def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:


# POST /v1/chat/completions
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
  # TODO: Check for length based on model context_length
  json_str = await req.body()
  try:
@@ -263,8 +260,8 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo


# POST /v1/completions
@add_schema_definitions(COMPLETION_SCHEMA)
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
  # TODO: Check for length based on model context_length
  json_str = await req.body()
  try:
@@ -62,7 +62,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
  config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
  _patch_correct_tag(llm, config)
  _, tokenizer_attrs = llm.llm_parameters
  quantize = llm._quantise
  quantize = llm.quantise
  safe_serialisation = openllm.utils.first_not_none(
    attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
  )
@@ -132,7 +132,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
  try:
    bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
    tokenizer.save_pretrained(bentomodel.path)
    if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
    if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
      attrs['quantization_config'] = llm.quantization_config
    if quantize == 'gptq':
      from optimum.gptq.constants import GPTQ_CONFIG
@@ -205,7 +205,7 @@ def check_unintialised_params(model):


def load_model(llm, *decls, **attrs):
  if llm._quantise in {'awq', 'squeezellm'}:
  if llm.quantise in {'awq', 'squeezellm'}:
    raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
  config, attrs = transformers.AutoConfig.from_pretrained(
    llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
@@ -217,7 +217,7 @@ def load_model(llm, *decls, **attrs):
    device_map = 'auto'
  elif torch.cuda.device_count() == 1:
    device_map = 'cuda:0'
  if llm._quantise in {'int8', 'int4'}:
  if llm.quantise in {'int8', 'int4'}:
    attrs['quantization_config'] = llm.quantization_config

  if '_quantize' in llm.bentomodel.info.metadata: