chore(llm): expose quantise and lazy load heavy imports (#617)

* chore(llm): expose quantise and lazy load heavy imports

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: move transformers to TYPE_CHECKING block

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron Pham authored on 2023-11-12 14:55:37 -05:00, committed by GitHub
commit 7e1fb35a71, parent 106e8617c1
8 changed files with 189 additions and 116 deletions
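The commit keeps torch and transformers out of module import time: transformers moves under the if t.TYPE_CHECKING: block (it is only needed for annotations at module scope), while torch is imported inside the properties that actually use it. A minimal sketch of that pattern, assuming torch/transformers are installed at call time; LazyLLM is a made-up stand-in, not the real class:

from __future__ import annotations
import typing as t

if t.TYPE_CHECKING:
  import transformers  # resolved by type checkers only; never imported at runtime here

class LazyLLM:
  @property
  def import_kwargs(self) -> dict[str, t.Any]:
    import torch  # deferred until first access, so importing this module stays cheap
    return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}

  @property
  def quantization_config(self) -> transformers.BitsAndBytesConfig:
    import transformers  # runtime import only on the code path that needs it
    return transformers.BitsAndBytesConfig(load_in_8bit=True)

With the future annotations import, the return annotation referencing transformers stays a string, so the module imports fine even though transformers is only imported under TYPE_CHECKING.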

View File

@@ -9,8 +9,6 @@ import typing as t
import attr
import inflection
import orjson
import torch
import transformers
from huggingface_hub import hf_hub_download
@@ -58,6 +56,7 @@ from .serialisation.constants import PEFT_CONFIG_NAME
if t.TYPE_CHECKING:
import peft
import transformers
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
@@ -124,8 +123,8 @@ class LLM(t.Generic[M, T], ReprMixin):
_quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
_quantise: LiteralQuantise | None
_model_decls: TupleAny
_model_attrs: DictStrAny
_tokenizer_attrs: DictStrAny
__model_attrs: DictStrAny
__tokenizer_attrs: DictStrAny
_tag: bentoml.Tag
_adapter_map: AdapterMap | None
_serialisation: LiteralSerialisation
@@ -133,7 +132,6 @@ class LLM(t.Generic[M, T], ReprMixin):
_prompt_template: PromptTemplate | None
_system_message: str | None
_bentomodel: bentoml.Model = attr.field(init=False)
__llm_config__: LLMConfig | None = None
__llm_backend__: LiteralBackend = None # type: ignore
__llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -168,16 +166,11 @@ class LLM(t.Generic[M, T], ReprMixin):
_local = False
if validate_is_path(model_id):
model_id, _local = resolve_filepath(model_id), True
backend = t.cast(
LiteralBackend,
first_not_none(
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
),
)
quantize = first_not_none(
quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None
)
backend = first_not_none(
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
)
quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None)
# elif quantization_config is None and quantize is not None:
# quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
@@ -199,17 +192,17 @@ class LLM(t.Generic[M, T], ReprMixin):
self.__attrs_init__(
model_id=model_id,
revision=model_version,
tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
tag=bentoml.Tag.from_taglike(model_tag),
quantization_config=quantization_config,
quantise=quantize,
model_decls=args,
model_attrs=dict(**self.import_kwargs[0], **model_attrs),
tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
serialisation=serialisation,
local=_local,
prompt_template=prompt_template,
system_message=system_message,
LLM__model_attrs=model_attrs,
LLM__tokenizer_attrs=tokenizer_attrs,
llm_backend__=backend,
llm_config__=llm_config,
llm_trust_remote_code__=trust_remote_code,
@@ -221,7 +214,6 @@ class LLM(t.Generic[M, T], ReprMixin):
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
# resolve the tag
self._tag = model.tag
self._bentomodel = model
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
@@ -241,72 +233,141 @@ class LLM(t.Generic[M, T], ReprMixin):
)
return f'{backend}-{normalise_model_name(model_id)}', model_version
# yapf: disable
def __setattr__(self,attr,value):
if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr,value)
def __setattr__(self, attr, value):
if attr in _reserved_namespace:
raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr, value)
@property
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def _model_attrs(self) -> dict[str, t.Any]:
return {**self.import_kwargs[0], **self.__model_attrs}
@property
def _tokenizer_attrs(self) -> dict[str, t.Any]:
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
@property
def __repr_keys__(self):
return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self):
yield 'model_id',self._model_id if not self._local else self.tag.name
yield 'revision',self._revision if self._revision else self.tag.version
yield 'backend',self.__llm_backend__
yield 'type',self.llm_type
yield 'model_id', self._model_id if not self._local else self.tag.name
yield 'revision', self._revision if self._revision else self.tag.version
yield 'backend', self.__llm_backend__
yield 'type', self.llm_type
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {
'device_map': 'auto' if torch.cuda.is_available() else None,
'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
}, {'padding_side': 'left', 'truncation_side': 'left'}
@property
def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
def trust_remote_code(self) -> bool:
return first_not_none(check_bool_env('TRUST_REMOTE_CODE', False), default=self.__llm_trust_remote_code__)
@property
def runner_name(self)->str:return f"llm-{self.config['start_name']}-runner"
def runner_name(self) -> str:
return f"llm-{self.config['start_name']}-runner"
@property
def model_id(self)->str:return self._model_id
def model_id(self) -> str:
return self._model_id
@property
def revision(self)->str:return t.cast(str, self._revision)
def revision(self) -> str:
return t.cast(str, self._revision)
@property
def tag(self)->bentoml.Tag:return self._tag
def tag(self) -> bentoml.Tag:
return self._tag
@property
def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
def bentomodel(self) -> bentoml.Model:
return openllm.serialisation.get(self)
@property
def quantization_config(self)->transformers.BitsAndBytesConfig|transformers.GPTQConfig|transformers.AwqConfig:
def quantization_config(self) -> transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig:
if self.__llm_quantization_config__ is None:
if self._quantization_config is not None:self.__llm_quantization_config__=self._quantization_config
elif self._quantise is not None:self.__llm_quantization_config__,self._model_attrs=infer_quantisation_config(self, self._quantise, **self._model_attrs)
else:raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
if self._quantization_config is not None:
self.__llm_quantization_config__ = self._quantization_config
elif self._quantise is not None:
self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
self, self._quantise, **self._model_attrs
)
else:
raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
return self.__llm_quantization_config__
@property
def has_adapters(self)->bool:return self._adapter_map is not None
def has_adapters(self) -> bool:
return self._adapter_map is not None
@property
def local(self)->bool:return self._local
def local(self) -> bool:
return self._local
@property
def quantise(self) -> LiteralQuantise | None:
return self._quantise
# NOTE: The section below defines a loose contract with langchain's LLM interface.
@property
def llm_type(self)->str:return normalise_model_name(self._model_id)
def llm_type(self) -> str:
return normalise_model_name(self._model_id)
@property
def identifying_params(self)->DictStrAny:return {'configuration': self.config.model_dump_json().decode(),'model_ids': orjson.dumps(self.config['model_ids']).decode(),'model_id': self.model_id}
def identifying_params(self) -> DictStrAny:
return {
'configuration': self.config.model_dump_json().decode(),
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
'model_id': self.model_id,
}
@property
def llm_parameters(self)->tuple[tuple[tuple[t.Any,...],DictStrAny],DictStrAny]:return (self._model_decls,self._model_attrs),self._tokenizer_attrs
def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], DictStrAny], DictStrAny]:
return (self._model_decls, self._model_attrs), self._tokenizer_attrs
# NOTE: This section is the actual model, tokenizer, and config reference here.
@property
def config(self)->LLMConfig:
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
def config(self) -> LLMConfig:
if self.__llm_config__ is None:
self.__llm_config__ = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
return self.__llm_config__
@property
def tokenizer(self)->T:
if self.__llm_tokenizer__ is None:self.__llm_tokenizer__=openllm.serialisation.load_tokenizer(self,**self.llm_parameters[-1])
def tokenizer(self) -> T:
if self.__llm_tokenizer__ is None:
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
return self.__llm_tokenizer__
@property
def runner(self)->LLMRunner[M, T]:
if self.__llm_runner__ is None:self.__llm_runner__=_RunnerFactory(self)
def runner(self) -> LLMRunner[M, T]:
if self.__llm_runner__ is None:
self.__llm_runner__ = _RunnerFactory(self)
return self.__llm_runner__
@property
def model(self)->M:
def model(self) -> M:
if self.__llm_model__ is None:
model=openllm.serialisation.load_model(self,*self._model_decls,**self._model_attrs)
model = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
# If this OOMs, you probably don't have enough VRAM to run this model.
if self.__llm_backend__ == 'pt':
loaded_in_kbit = getattr(model,'is_loaded_in_8bit',False) or getattr(model,'is_loaded_in_4bit',False) or getattr(model,'is_quantized',False)
import torch
loaded_in_kbit = (
getattr(model, 'is_loaded_in_8bit', False)
or getattr(model, 'is_loaded_in_4bit', False)
or getattr(model, 'is_quantized', False)
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
try: model = model.to('cuda')
except Exception as err: raise OpenLLMException(f'Failed to load model into GPU: {err}\n. See https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.') from err
try:
model = model.to('cuda')
except Exception as err:
raise OpenLLMException(f'Failed to load model into GPU: {err}.\n') from err
if self.has_adapters:
logger.debug('Applying the following adapters: %s', self.adapter_map)
for adapter_dict in self.adapter_map.values():
@@ -314,23 +375,29 @@ class LLM(t.Generic[M, T], ReprMixin):
model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
self.__llm_model__ = model
return self.__llm_model__
@property
def adapter_map(self) -> ResolvedAdapterMap:
try:
import peft as _ # noqa: F401
except ImportError as err:
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err
if not self.has_adapters: raise AttributeError('Adapter map is not available.')
raise MissingDependencyError(
"Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'"
) from err
if not self.has_adapters:
raise AttributeError('Adapter map is not available.')
assert self._adapter_map is not None
if self.__llm_adapter_map__ is None:
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
for adapter_type, adapter_tuple in self._adapter_map.items():
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
base = first_not_none(
self.config['fine_tune_strategies'].get(adapter_type),
default=self.config.make_fine_tune_config(adapter_type),
)
for adapter in adapter_tuple:
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
self.__llm_adapter_map__ = _map
return self.__llm_adapter_map__
# yapf: enable
def prepare_for_training(
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
@@ -475,15 +542,24 @@ def _RunnerFactory(
else:
system_message = None
# yapf: disable
def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]:
return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
def _wrapped_repr_args(_: LLMRunner[M, T]) -> ReprArgs:
yield 'runner_methods', {method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None} for method in _.runner_methods}
yield (
'runner_methods',
{
method.name: {
'batchable': method.config.batchable,
'batch_dim': method.config.batch_dim if method.config.batchable else None,
}
for method in _.runner_methods
},
)
yield 'config', self.config.model_dump(flatten=True)
yield 'llm_type', self.llm_type
yield 'backend', backend
yield 'llm_tag', self.tag
# yapf: enable
return types.new_class(
self.__class__.__name__ + 'Runner',

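The __model_attrs/LLM__model_attrs pairing in the diff above comes from Python name mangling plus attrs' keyword rules: a class-private field __model_attrs becomes _LLM__model_attrs, and attrs strips the single leading underscore, so the generated init keyword is LLM__model_attrs. A minimal sketch under those assumptions (a standalone toy class, not the real LLM):

import attr

@attr.define
class LLM:
  __model_attrs: dict = attr.field(factory=dict)      # mangled to _LLM__model_attrs
  __tokenizer_attrs: dict = attr.field(factory=dict)  # init keyword becomes LLM__tokenizer_attrs

  @property
  def _model_attrs(self) -> dict:
    # Merged view: defaults first, stored overrides last,
    # mirroring {**self.import_kwargs[0], **self.__model_attrs} above.
    return {'low_cpu_mem_usage': True, **self.__model_attrs}

llm = LLM(LLM__model_attrs={'torch_dtype': 'float16'})
print(llm._model_attrs)  # {'low_cpu_mem_usage': True, 'torch_dtype': 'float16'}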
View File

@@ -50,8 +50,8 @@ class vLLMRunnable(bentoml.Runnable):
if dev >= 2:
num_gpus = min(dev // 2 * 2, dev)
quantization = None
if llm._quantise and llm._quantise in {'awq', 'squeezellm'}:
quantization = llm._quantise
if llm.quantise and llm.quantise in {'awq', 'squeezellm'}:
quantization = llm.quantise
try:
self.model = vllm.AsyncLLMEngine.from_engine_args(
vllm.AsyncEngineArgs(
@@ -111,7 +111,6 @@ class PyTorchRunnable(bentoml.Runnable):
self.model = llm.model
self.tokenizer = llm.tokenizer
self.config = llm.config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(
@@ -155,17 +154,17 @@ class PyTorchRunnable(bentoml.Runnable):
finish_reason: t.Optional[FinishReason] = None
for i in range(config['max_new_tokens']):
if i == 0: # prefill
out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True)
else: # decoding
out = self.model(
torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values
torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values
)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if config['repetition_penalty'] > 1.0:
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.model.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
@@ -173,7 +172,7 @@ class PyTorchRunnable(bentoml.Runnable):
last_token_logits = logits[0, -1, :]
# Switch to CPU to avoid some bugs in the mps backend.
if self.device.type == 'mps':
if self.model.device.type == 'mps':
last_token_logits = last_token_logits.float().to('cpu')
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy

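Dropping the runnable's own self.device in favour of self.model.device works because transformers models report the device their parameters actually live on, so there is no separate torch.device attribute that can drift out of sync once the model has been placed (for example via device_map='auto'). A small sketch, with 'gpt2' as an arbitrary stand-in checkpoint:

import torch
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# model.device reflects where the weights currently are, so new tensors can be
# created on the right device without tracking it separately.
prompt_token_ids = [0, 1, 2]
out = model(torch.as_tensor([prompt_token_ids], device=model.device), use_cache=True)
print(model.device, out.logits.shape)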
View File

@@ -148,7 +148,7 @@ def construct_docker_options(
if llm._prompt_template:
env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
if quantize:
env_dict['OPENLLM_QUANTISE'] = str(quantize)
env_dict['OPENLLM_QUANTIZE'] = str(quantize)
return DockerOptions(
base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}',
env=env_dict,

View File

@@ -621,8 +621,8 @@ def process_environ(
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
}
)
if llm._quantise:
environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
if llm.quantise:
environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
if system_message:
environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:
@@ -650,8 +650,8 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
cmd_name = f'openllm build {model_id}'
if llm._quantise:
cmd_name += f' --quantize {llm._quantise}'
if llm.quantise:
cmd_name += f' --quantize {llm.quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
@@ -994,8 +994,8 @@ def build_command(
'OPENLLM_MODEL_ID': llm.model_id,
}
)
if llm._quantise:
os.environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
if llm.quantise:
os.environ['OPENLLM_QUANTIZE'] = str(llm.quantise)
if system_message:
os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template:

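The build and serve paths above now consistently export OPENLLM_QUANTIZE, which the LLM constructor reads back through first_not_none. A tiny round-trip sketch; first_not_none is re-implemented here purely to show its observable behaviour and is not the library's implementation:

import os

def first_not_none(*args, default=None):
  # Illustrative stand-in: return the first argument that is not None.
  return next((a for a in args if a is not None), default)

# What openllm build / openllm start now export (note QUANTIZE, not QUANTISE) ...
os.environ['OPENLLM_QUANTIZE'] = 'int8'
# ... and how the constructor picks it up when no explicit value is passed.
quantize = first_not_none(None, os.getenv('OPENLLM_QUANTIZE'), default=None)
print(quantize)  # 'int8'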
View File

@@ -24,7 +24,7 @@ if t.TYPE_CHECKING:
P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODEL_SCHEMA = """\
LIST_MODELS_SCHEMA = """\
---
consumes:
- application/json
@@ -55,14 +55,14 @@ responses:
schema:
$ref: '#/components/schemas/ModelList'
"""
CHAT_COMPLETION_SCHEMA = """\
CHAT_COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a list of messages comprising a conversation, the model will return a
response.
operationId: openai__create_chat_completions
operationId: openai__chat_completions
produces:
- application/json
tags:
@@ -193,7 +193,7 @@ responses:
}
description: Bad Request
"""
COMPLETION_SCHEMA = """\
COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
@@ -201,7 +201,7 @@ description: >-
Given a prompt, the model will return one or more predicted completions, and
can also return the probabilities of alternative tokens at each position. We
recommend most users use our Chat completions API.
operationId: openai__create_completions
operationId: openai__completions
produces:
- application/json
tags:
@@ -423,15 +423,17 @@ responses:
description: Not Found
"""
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
return func
return docstring_decorator
def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
return func
if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
return func
class OpenLLMSchemaGenerator(SchemaGenerator):
@@ -558,7 +560,7 @@ def append_schemas(
# HACK: Dirty hack to append schemas to an existing service. We definitely need to support mounting a Starlette app's OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
svc_schema: t.Any = svc.openapi_spec
svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
@@ -572,14 +574,15 @@ def append_schemas(
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
svc_schema['paths'].update(generated_schema['paths'])
from bentoml._internal.service import (
openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
)
# HACK: mk this attribute until we have a better way to add starlette schemas.
from bentoml._internal.service import openapi
# yapf: disable
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
def mk_asdict(self:OpenAPISpecification)->dict[str,t.Any]:return svc_schema
openapi.generate_spec=mk_generate_spec
def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
return MKSchema(svc_schema)
def mk_asdict(self):
return svc_schema
openapi.generate_spec = mk_generate_spec
OpenAPISpecification.asdict = mk_asdict
# yapf: disable
return svc
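The reworked decorator above keys the schema on the endpoint's function name via _SCHEMAS instead of taking the schema as an argument, which is why the *_SCHEMA constants and handlers were renamed in lock-step (LIST_MODELS_SCHEMA with list_models, and so on). A condensed, self-contained sketch of that mechanism with placeholder schema strings, not the real OpenLLM ones:

import typing as t

LIST_MODELS_SCHEMA = 'responses: {200: OK}'
CHAT_COMPLETIONS_SCHEMA = 'responses: {200: OK, 400: Bad Request}'
# Strip the trailing '_SCHEMA' and lowercase, so keys match endpoint function names.
_SCHEMAS = {k[:-7].lower(): v for k, v in dict(globals()).items() if k.endswith('_SCHEMA')}

def add_schema_definitions(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
  append_str = _SCHEMAS.get(func.__name__.lower(), '')
  if not append_str:
    return func
  func.__doc__ = (func.__doc__ or '').strip() + '\n\n' + append_str.strip()
  return func

@add_schema_definitions
def list_models():
  """GET /v1/models"""

print(list_models.__doc__)  # docstring now ends with the schema block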

View File

@@ -14,8 +14,6 @@ from starlette.routing import Route
from openllm_core.utils import converter
from ._openapi import HF_ADAPTERS_SCHEMA
from ._openapi import HF_AGENT_SCHEMA
from ._openapi import add_schema_definitions
from ._openapi import append_schemas
from ._openapi import get_generator
@@ -54,7 +52,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
debug=True,
routes=[
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
],
)
@@ -71,7 +69,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
)
@add_schema_definitions(HF_AGENT_SCHEMA)
@add_schema_definitions
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
json_str = await req.body()
try:
@@ -92,8 +90,8 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
@add_schema_definitions(HF_ADAPTERS_SCHEMA)
def adapters_map(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
def hf_adapters(req: Request, llm: openllm.LLM[M, T]) -> Response:
if not llm.has_adapters:
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
return JSONResponse(

View File

@@ -18,9 +18,6 @@ from openllm_core._schemas import SampleLogprobs
from openllm_core.utils import converter
from openllm_core.utils import gen_random_uuid
from ._openapi import CHAT_COMPLETION_SCHEMA
from ._openapi import COMPLETION_SCHEMA
from ._openapi import LIST_MODEL_SCHEMA
from ._openapi import add_schema_definitions
from ._openapi import append_schemas
from ._openapi import get_generator
@@ -127,8 +124,8 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
],
)
mount_path = '/v1'
@@ -138,7 +135,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
# GET /v1/models
@add_schema_definitions(LIST_MODEL_SCHEMA)
@add_schema_definitions
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
return JSONResponse(
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
@@ -146,8 +143,8 @@ def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
# POST /v1/chat/completions
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
# TODO: Check for length based on model context_length
json_str = await req.body()
try:
@@ -263,8 +260,8 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
# POST /v1/completions
@add_schema_definitions(COMPLETION_SCHEMA)
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
# TODO: Check for length based on model context_length
json_str = await req.body()
try:

View File

@@ -62,7 +62,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm._quantise
quantize = llm.quantise
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
@@ -132,7 +132,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
@@ -205,7 +205,7 @@ def check_unintialised_params(model):
def load_model(llm, *decls, **attrs):
if llm._quantise in {'awq', 'squeezellm'}:
if llm.quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
@@ -217,7 +217,7 @@ def load_model(llm, *decls, **attrs):
device_map = 'auto'
elif torch.cuda.device_count() == 1:
device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}:
if llm.quantise in {'int8', 'int4'}:
attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm.bentomodel.info.metadata:
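For reference, the quantise aliases consumed here ('int8' and 'int4' via infer_quantisation_config, plus 'gptq', 'awq', 'squeezellm') ultimately resolve to a transformers quantization config on the PyTorch path. A hypothetical, much-reduced mapping to illustrate the idea; the real logic lives in infer_quantisation_config and supports many more knobs:

import transformers

def toy_quantisation_config(quantise: str) -> transformers.BitsAndBytesConfig:
  # Illustrative only: maps the 'int8'/'int4' aliases to bitsandbytes configs.
  if quantise == 'int8':
    return transformers.BitsAndBytesConfig(load_in_8bit=True)
  if quantise == 'int4':
    return transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype='bfloat16')
  raise ValueError(f'unsupported quantise value: {quantise}')

print(toy_quantisation_config('int8').load_in_8bit)  # True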