Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-01-27 00:38:04 -05:00
feat(vllm): support passing specific dtype (#626)
* feat(vllm): support passing specific dtype
* fix: correctly cached the item
* ci: auto fixes from pre-commit.ci. For more information, see https://pre-commit.ci

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
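For reference, a minimal sketch of how the new parameter is meant to be used. The model id below is a placeholder and the constructor call is assumed from the diff; the torch_dtype argument is the one this change adds:

import openllm

# Placeholder model id; torch_dtype accepts 'auto' (default), 'half'/'float16',
# 'float'/'float32', or 'bfloat16' per the mapping introduced below.
llm = openllm.LLM('facebook/opt-125m', backend='vllm', torch_dtype='float16')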
@@ -1,6 +1,7 @@
 # mypy: disable-error-code="name-defined,attr-defined"
 from __future__ import annotations
 import abc
+import functools
 import logging
 import os
 import types
@@ -23,6 +24,7 @@ from openllm_core._typing_compat import (
   AdapterType,
   DictStrAny,
   LiteralBackend,
+  LiteralDtype,
   LiteralQuantise,
   LiteralSerialisation,
   M,
@@ -56,6 +58,7 @@ from .exceptions import ForbiddenAttributeError, OpenLLMException
 from .serialisation.constants import PEFT_CONFIG_NAME
 
 if t.TYPE_CHECKING:
+  import torch
   import transformers
   from peft.config import PeftConfig
   from peft.peft_model import PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM
@@ -115,6 +118,19 @@ _reserved_namespace = {'model', 'tokenizer', 'runner', 'import_kwargs'}
 _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])
 
 
+@functools.lru_cache(maxsize=1)
+def _torch_dtype_mapping():
+  import torch
+
+  return {
+    'half': torch.float16,
+    'float16': torch.float16,
+    'float': torch.float32,
+    'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
+  }
+
+
 @attr.define(slots=True, repr=False, init=False)
 class LLM(t.Generic[M, T], ReprMixin):
   _model_id: str
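A note on the shape of this helper: the torch import is deferred into the function body, and lru_cache(maxsize=1) memoises the table, so torch is imported and the dict built only on first use. A minimal standalone sketch of the same pattern (_dtype_table is a hypothetical name, not part of this change):

import functools


@functools.lru_cache(maxsize=1)
def _dtype_table():
  import torch  # deferred: only paid on the first call

  return {'half': torch.float16, 'float16': torch.float16, 'bfloat16': torch.bfloat16}


# The cached call returns the very same dict object every time.
assert _dtype_table() is _dtype_table()
assert _dtype_table()['half'] is _dtype_table()['float16']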
@@ -131,6 +147,7 @@ class LLM(t.Generic[M, T], ReprMixin):
   _prompt_template: PromptTemplate | None
   _system_message: str | None
 
+  __llm_torch_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
   __llm_config__: LLMConfig | None = None
   __llm_backend__: LiteralBackend = None  # type: ignore
   __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -159,6 +176,7 @@ class LLM(t.Generic[M, T], ReprMixin):
     serialisation: LiteralSerialisation = 'safetensors',
     trust_remote_code: bool = False,
     embedded: bool = False,
+    torch_dtype: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto',
     **attrs: t.Any,
   ):
     # low_cpu_mem_usage is only available for models; it is helpful on systems with low memory to avoid OOM
@@ -203,6 +221,7 @@ class LLM(t.Generic[M, T], ReprMixin):
       system_message=system_message,
       LLM__model_attrs=model_attrs,
       LLM__tokenizer_attrs=tokenizer_attrs,
+      llm_torch_dtype__=torch_dtype.lower(),
       llm_backend__=backend,
       llm_config__=llm_config,
       llm_trust_remote_code__=trust_remote_code,
@@ -223,6 +242,31 @@ class LLM(t.Generic[M, T], ReprMixin):
       logger.info("To disable this warning, set 'OPENLLM_DISABLE_WARNING=True'")
       self.runner.init_local(quiet=True)
 
+  @property
+  def _torch_dtype(self) -> torch.dtype:
+    import torch
+    import transformers
+
+    if not isinstance(self.__llm_torch_dtype__, torch.dtype):
+      config_dtype = getattr(
+        transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code),
+        'torch_dtype',
+        None,
+      )
+      if config_dtype is None:
+        config_dtype = torch.float32
+      if self.__llm_torch_dtype__ == 'auto':
+        if config_dtype == torch.float32:
+          torch_dtype = torch.float16  # following common practice
+        else:
+          torch_dtype = config_dtype
+      else:
+        if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
+          raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")
+        torch_dtype = _torch_dtype_mapping()[self.__llm_torch_dtype__]
+      self.__llm_torch_dtype__ = torch_dtype
+    return self.__llm_torch_dtype__
+
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
     model_id, *maybe_revision = model_id.rsplit(':')
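The resolution order in _torch_dtype is: an explicit alias wins, 'auto' defers to the checkpoint config, and a float32 (or missing) config dtype is downcast to float16. A standalone sketch of that decision, where config_dtype stands in for whatever AutoConfig reports (resolve_dtype is hypothetical, for illustration only):

from __future__ import annotations

import torch

_MAPPING = {
  'half': torch.float16,
  'float16': torch.float16,
  'float': torch.float32,
  'float32': torch.float32,
  'bfloat16': torch.bfloat16,
}


def resolve_dtype(requested: str, config_dtype: torch.dtype | None) -> torch.dtype:
  # Mirrors the property above: 'auto' follows the checkpoint config, except that
  # float32 (or an unset config) is downcast to float16; anything else must be a known alias.
  if config_dtype is None:
    config_dtype = torch.float32
  if requested == 'auto':
    return torch.float16 if config_dtype == torch.float32 else config_dtype
  if requested not in _MAPPING:
    raise ValueError(f"Unknown dtype '{requested}'")
  return _MAPPING[requested]


assert resolve_dtype('auto', None) is torch.float16  # no config dtype -> float32 -> float16
assert resolve_dtype('auto', torch.bfloat16) is torch.bfloat16
assert resolve_dtype('half', torch.float32) is torch.float16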
@@ -268,10 +312,10 @@ class LLM(t.Generic[M, T], ReprMixin):
   def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
     import torch
 
-    return {
-      'device_map': 'auto' if torch.cuda.is_available() else None,
-      'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
-    }, {'padding_side': 'left', 'truncation_side': 'left'}
+    return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': self._torch_dtype}, {
+      'padding_side': 'left',
+      'truncation_side': 'left',
+    }
 
   @property
   def trust_remote_code(self) -> bool:
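For context, import_kwargs returns a pair of dicts (model kwargs, tokenizer kwargs); after this change the model dict carries the resolved self._torch_dtype instead of a hard-coded float16/float32 split. A sketch of the shape with placeholder values:

import torch

# Placeholder values; in the property above the dtype comes from self._torch_dtype.
model_kwargs = {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16}
tokenizer_kwargs = {'padding_side': 'left', 'truncation_side': 'left'}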
The remaining hunks are from the vLLM runner:

@@ -11,7 +11,7 @@ import openllm
 from openllm.exceptions import OpenLLMException
 from openllm_core._schemas import CompletionChunk, GenerationOutput
 from openllm_core._typing_compat import LiteralBackend, M, T
-from openllm_core.utils import first_not_none, get_debug_mode, is_vllm_available
+from openllm_core.utils import first_not_none, is_vllm_available
 
 if t.TYPE_CHECKING:
   import vllm
@@ -53,9 +53,8 @@ class vLLMRunnable(bentoml.Runnable):
         trust_remote_code=llm.trust_remote_code,
         tokenizer_mode='auto',
         tensor_parallel_size=num_gpus,
-        dtype='auto',
+        dtype=str(llm._torch_dtype).split('.')[-1],
         quantization=quantization,
-        disable_log_requests=not get_debug_mode(),
         worker_use_ray=False,
         engine_use_ray=False,
       )
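The engine now receives the resolved dtype as a short string, since vLLM's dtype engine argument expects names such as 'float16' rather than torch.dtype objects. The conversion leans on the string form of a torch dtype:

import torch

# str(torch.float16) == 'torch.float16'; splitting on '.' keeps the short name passed to vLLM.
assert str(torch.float16).split('.')[-1] == 'float16'
assert str(torch.bfloat16).split('.')[-1] == 'bfloat16'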