From 22eaaf3ce1f5d1f485f1fece86febf31dca8ad31 Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Mon, 13 Nov 2023 05:08:33 -0500
Subject: [PATCH] feat(vllm): support passing specific dtype (#626)

* feat(vllm): support passing specific dtype

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: correctly cached the item

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../src/openllm_core/_typing_compat.py |  1 +
 openllm-python/src/openllm/_llm.py     | 52 +++++++++++++++++--
 openllm-python/src/openllm/_runners.py |  5 +-
 3 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/openllm-core/src/openllm_core/_typing_compat.py b/openllm-core/src/openllm_core/_typing_compat.py
index ccac4483..c926635f 100644
--- a/openllm-core/src/openllm_core/_typing_compat.py
+++ b/openllm-core/src/openllm_core/_typing_compat.py
@@ -30,6 +30,7 @@
 ListStr = t.List[str]
 TupleAny = t.Tuple[t.Any, ...]
 At = t.TypeVar('At', bound=attr.AttrsInstance)
+LiteralDtype = t.Literal['float16', 'float32', 'bfloat16']
 LiteralSerialisation = t.Literal['safetensors', 'legacy']
 LiteralQuantise = t.Literal['int8', 'int4', 'gptq', 'awq', 'squeezellm']
 LiteralBackend = t.Literal['pt', 'vllm', 'ggml', 'mlc']
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index abf71df9..887ed3f1 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -1,6 +1,7 @@
 # mypy: disable-error-code="name-defined,attr-defined"
 from __future__ import annotations
 import abc
+import functools
 import logging
 import os
 import types
@@ -23,6 +24,7 @@ from openllm_core._typing_compat import (
   AdapterType,
   DictStrAny,
   LiteralBackend,
+  LiteralDtype,
   LiteralQuantise,
   LiteralSerialisation,
   M,
@@ -56,6 +58,7 @@ from .exceptions import ForbiddenAttributeError, OpenLLMException
 from .serialisation.constants import PEFT_CONFIG_NAME

 if t.TYPE_CHECKING:
+  import torch
   import transformers
   from peft.config import PeftConfig
   from peft.peft_model import PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM
@@ -115,6 +118,19 @@ _reserved_namespace = {'model', 'tokenizer', 'runner', 'import_kwargs'}
 _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])


+@functools.lru_cache(maxsize=1)
+def _torch_dtype_mapping():
+  import torch
+
+  return {
+    'half': torch.float16,
+    'float16': torch.float16,
+    'float': torch.float32,
+    'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
+  }
+
+
 @attr.define(slots=True, repr=False, init=False)
 class LLM(t.Generic[M, T], ReprMixin):
   _model_id: str
@@ -131,6 +147,7 @@ class LLM(t.Generic[M, T], ReprMixin):
   _prompt_template: PromptTemplate | None
   _system_message: str | None

+  __llm_torch_dtype__: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto'
   __llm_config__: LLMConfig | None = None
   __llm_backend__: LiteralBackend = None  # type: ignore
   __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -159,6 +176,7 @@ class LLM(t.Generic[M, T], ReprMixin):
     serialisation: LiteralSerialisation = 'safetensors',
     trust_remote_code: bool = False,
     embedded: bool = False,
+    torch_dtype: LiteralDtype | t.Literal['auto', 'half', 'float'] = 'auto',
     **attrs: t.Any,
   ):
     # low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
@@ -203,6 +221,7 @@ class LLM(t.Generic[M, T], ReprMixin):
       system_message=system_message,
       LLM__model_attrs=model_attrs,
       LLM__tokenizer_attrs=tokenizer_attrs,
+      llm_torch_dtype__=torch_dtype.lower(),
       llm_backend__=backend,
       llm_config__=llm_config,
       llm_trust_remote_code__=trust_remote_code,
@@ -223,6 +242,31 @@ class LLM(t.Generic[M, T], ReprMixin):
         logger.info("To disable this warning, set 'OPENLLM_DISABLE_WARNING=True'")
       self.runner.init_local(quiet=True)

+  @property
+  def _torch_dtype(self) -> torch.dtype:
+    import torch
+    import transformers
+
+    if not isinstance(self.__llm_torch_dtype__, torch.dtype):
+      config_dtype = getattr(
+        transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code),
+        'torch_dtype',
+        None,
+      )
+      if config_dtype is None:
+        config_dtype = torch.float32
+      if self.__llm_torch_dtype__ == 'auto':
+        if config_dtype == torch.float32:
+          torch_dtype = torch.float16  # following common practice
+        else:
+          torch_dtype = config_dtype
+      else:
+        if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
+          raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")
+        torch_dtype = _torch_dtype_mapping()[self.__llm_torch_dtype__]
+      self.__llm_torch_dtype__ = torch_dtype
+    return self.__llm_torch_dtype__
+
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(self, model_id, model_version, backend) -> tuple[str, str | None]:
     model_id, *maybe_revision = model_id.rsplit(':')
@@ -268,10 +312,10 @@ class LLM(t.Generic[M, T], ReprMixin):
   def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
     import torch

-    return {
-      'device_map': 'auto' if torch.cuda.is_available() else None,
-      'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32,
-    }, {'padding_side': 'left', 'truncation_side': 'left'}
+    return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': self._torch_dtype}, {
+      'padding_side': 'left',
+      'truncation_side': 'left',
+    }

   @property
   def trust_remote_code(self) -> bool:
diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py
index 1bb4a7d2..ddbdb972 100644
--- a/openllm-python/src/openllm/_runners.py
+++ b/openllm-python/src/openllm/_runners.py
@@ -11,7 +11,7 @@ import openllm
 from openllm.exceptions import OpenLLMException
 from openllm_core._schemas import CompletionChunk, GenerationOutput
 from openllm_core._typing_compat import LiteralBackend, M, T
-from openllm_core.utils import first_not_none, get_debug_mode, is_vllm_available
+from openllm_core.utils import first_not_none, is_vllm_available

 if t.TYPE_CHECKING:
   import vllm
@@ -53,9 +53,8 @@ class vLLMRunnable(bentoml.Runnable):
         trust_remote_code=llm.trust_remote_code,
         tokenizer_mode='auto',
         tensor_parallel_size=num_gpus,
-        dtype='auto',
+        dtype=str(llm._torch_dtype).split('.')[-1],
         quantization=quantization,
-        disable_log_requests=not get_debug_mode(),
         worker_use_ray=False,
         engine_use_ray=False,
       )
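
A minimal usage sketch of what this patch enables (the model id and the exact keyword set shown here are illustrative assumptions, not part of the patch): a caller can now pin the weights dtype explicitly instead of relying on 'auto' detection from the Hugging Face config.

import openllm

# Illustrative only: pin bfloat16 explicitly. With 'auto', the dtype recorded in the
# model's HF config is used, falling back to float16 when that config says float32.
llm = openllm.LLM(
  'facebook/opt-1.3b',     # hypothetical model id, used only for illustration
  backend='vllm',
  torch_dtype='bfloat16',  # accepted: 'float16', 'float32', 'bfloat16', 'half', 'float', or 'auto'
)

The value is normalised via torch_dtype.lower(), resolved lazily by the _torch_dtype property, and the vLLM runner forwards it to the engine as a plain string (e.g. 'bfloat16') through str(llm._torch_dtype).split('.')[-1].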