From 73d152fc776deef28f04fcfe5278657aae2f3c80 Mon Sep 17 00:00:00 2001
From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Date: Thu, 25 May 2023 02:10:36 +0000
Subject: [PATCH] feat(gpu): Make sure that we run models on GPU if available

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 pyproject.toml                              |  2 +-
 src/openllm/_configuration.py               |  3 ++
 src/openllm/_llm.py                         | 52 +++++++++++++++----
 src/openllm/{types.py => _types.py}         |  0
 .../models/chatglm/configuration_chatglm.py |  2 +-
 .../models/chatglm/modeling_chatglm.py      | 35 +++++++------
 .../models/dolly_v2/modeling_dolly_v2.py    | 29 ++++++-----
 7 files changed, 85 insertions(+), 38 deletions(-)
 rename src/openllm/{types.py => _types.py} (100%)

diff --git a/pyproject.toml b/pyproject.toml
index 157fe18d..b9547f8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,7 @@ dependencies = [
     # bentoml[grpc,grpc-reflection] include grpcio, grpcio-reflection
     "bentoml[io-image,io-pandas,io-file,grpc,grpc-reflection]>=1.0.19",
     # bentoml[torch] includes torch and transformers
-    "transformers[torch]>=4.29.0",
+    "transformers[torch,accelerate,tokenizers,onnxruntime,onnx]>=4.29.0",
     # Super fast JSON serialization
     "orjson",
     "inflection",
diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py
index 8c765ad0..0c1d33ac 100644
--- a/src/openllm/_configuration.py
+++ b/src/openllm/_configuration.py
@@ -431,6 +431,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
     __openllm_start_name__: str = ""
     __openllm_timeout__: int = 0
     __openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
+    __openllm_trust_remote_code__: bool = False
     GenerationConfig: type[t.Any] = GenerationConfig
 
     def __init_subclass__(
@@ -438,6 +439,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
         *,
         default_timeout: int | None = None,
         name_type: t.Literal["dasherize", "lowercase"] = "dasherize",
+        trust_remote_code: bool = False,
         **kwargs: t.Any,
     ):
         if default_timeout is None:
@@ -446,6 +448,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
         if name_type not in ("dasherize", "lowercase"):
             raise RuntimeError(f"Unknown name_type {name_type}. Only allowed are 'dasherize' and 'lowercase'.")
         cls.__openllm_name_type__ = name_type
+        cls.__openllm_trust_remote_code__ = trust_remote_code
 
         super(LLMConfig, cls).__init_subclass__(**kwargs)
 
diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 5caae2ba..012dd492 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -35,7 +35,7 @@ if t.TYPE_CHECKING:
     import transformers
     from bentoml._internal.runner.strategy import Strategy
 
-    from .types import LLMModel, LLMTokenizer, ModelSignatureDict
+    from ._types import LLMModel, LLMTokenizer, ModelSignatureDict
 else:
     ModelSignatureDict = dict
     transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
@@ -296,7 +296,8 @@ class LLM(LLMInterface):
         dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat8, device_map='gpu')
         ```
 
-        Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.
+        Note: If you implement your own `import_model`, then `import_kwargs` will be the default kwargs for every load. You can still override those
+        via ``openllm.Runner``.
 
         Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds.
        passed from the __init__ constructor.
@@ -334,11 +335,26 @@ class LLM(LLMInterface):
 
     @property
     def _bentomodel(self) -> bentoml.Model:
         if self.__bentomodel__ is None:
-            tag, kwds = openllm.utils.generate_tags(self._pretrained, prefix=self._implementation, **self._kwargs)
+            trust_remote_code = self._kwargs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
+            tag, kwds = openllm.utils.generate_tags(
+                self._pretrained, prefix=self._implementation, trust_remote_code=trust_remote_code, **self._kwargs
+            )
             tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
             kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}
 
+            if self.import_kwargs:
+                tokenizer_kwds = {
+                    **{
+                        k[len("_tokenizer_") :]: v for k, v in self.import_kwargs.items() if k.startswith("_tokenizer_")
+                    },
+                    **tokenizer_kwds,
+                }
+                kwds = {
+                    **{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")},
+                    **kwds,
+                }
+
             try:
                 self.__bentomodel__ = bentoml.transformers.get(tag)
             except bentoml.exceptions.BentoMLException:
@@ -348,17 +364,21 @@ class LLM(LLMInterface):
             if hasattr(self, "import_model"):
                 logger.debug("Using custom 'import_model' defined in subclass.")
                 self.__bentomodel__ = self.import_model(
-                    self._pretrained, tag, *self._args, tokenizer_kwds=tokenizer_kwds, **kwds
+                    self._pretrained,
+                    tag,
+                    *self._args,
+                    tokenizer_kwds=tokenizer_kwds,
+                    trust_remote_code=trust_remote_code,
+                    **kwds,
                 )
             else:
-                if self.import_kwargs:
-                    kwds = {**self.import_kwargs, **kwds}
                 # NOTE: In this branch, we just use the default implementation.
                 self.__bentomodel__ = import_model(
                     self._pretrained,
                     tag,
                     *self._args,
                     tokenizer_kwds=tokenizer_kwds,
+                    trust_remote_code=trust_remote_code,
                     __openllm_framework__=self._implementation,
                     **kwds,
                 )
@@ -473,10 +493,24 @@
 
 
 def Runner(start_name: str, **kwds: t.Any):
+    """Create a Runner for a given LLM. For a list of currently supported LLMs, see 'openllm models'.
+
+    Args:
+        start_name: Supported model name from 'openllm models'
+        init_local: Whether to initialize this Runner locally. Useful during development. (Defaults to False)
+        **kwds: The rest of the kwargs will be passed to the LLM.
+            Refer to the LLM documentation for their behaviour.
+    """
+    init_local = kwds.pop("init_local", False)
     envvar = openllm.utils.get_framework_env(start_name)
     if envvar == "flax":
-        return openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoFlaxLLM.create_runner(start_name, **kwds)
     elif envvar == "tf":
-        return openllm.AutoTFLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoTFLLM.create_runner(start_name, **kwds)
     else:
-        return openllm.AutoLLM.create_runner(start_name, **kwds)
+        runner = openllm.AutoLLM.create_runner(start_name, **kwds)
+
+    if init_local:
+        runner.init_local()
+
+    return runner
diff --git a/src/openllm/types.py b/src/openllm/_types.py
similarity index 100%
rename from src/openllm/types.py
rename to src/openllm/_types.py
diff --git a/src/openllm/models/chatglm/configuration_chatglm.py b/src/openllm/models/chatglm/configuration_chatglm.py
index f3711533..120f94bb 100644
--- a/src/openllm/models/chatglm/configuration_chatglm.py
+++ b/src/openllm/models/chatglm/configuration_chatglm.py
@@ -16,7 +16,7 @@ from __future__ import annotations
 import openllm
 
 
-class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase"):
+class ChatGLMConfig(openllm.LLMConfig, name_type="lowercase", trust_remote_code=True, default_timeout=3600000):
     """Configuration for the ChatGLM model."""
 
     retain_history: bool = True
diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py
index 19ba4022..237a4f85 100644
--- a/src/openllm/models/chatglm/modeling_chatglm.py
+++ b/src/openllm/models/chatglm/modeling_chatglm.py
@@ -44,6 +44,8 @@ class ChatGLM(openllm.LLM, _internal=True):
 
     variants = ["THUDM/chatglm-6b", "THUDM/chatglm-6b-int8", "THUDM/chatglm-6b-int4"]
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def model_post_init(self, _: t.Any):
         self.history: list[tuple[str, str]] = []
 
@@ -70,9 +72,9 @@ class ChatGLM(openllm.LLM, _internal=True):
         temperature: float | None = None,
         **kwargs: t.Any,
     ) -> t.Any:
-        if self.config.use_half_precision:
-            self.model = self.model.half()
         if torch.cuda.is_available():
+            if self.config.use_half_precision:
+                self.model = self.model.half()
             self.model = self.model.cuda()
         else:
             self.model = self.model.float()
@@ -90,19 +92,22 @@ class ChatGLM(openllm.LLM, _internal=True):
             prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
         prompt_text += f"[Round {len(self.history)}]\n问:{prompt}\n答:"
 
-        inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
-        outputs = self.model.generate(
-            **inputs,
-            generation_config=self.config.with_options(
-                max_length=max_length,
-                num_beams=num_beams,
-                top_p=top_p,
-                temperature=temperature,
-                do_sample=True,
-                **kwargs,
-            ).to_generation_config(),
-            logits_processor=logit_processor,
-        )
+        inputs = self.tokenizer([prompt_text], return_tensors="pt").to(self.device)
+        with torch.device(self.device):
+            outputs = self.model.generate(
+                **inputs,
+                generation_config=self.config.with_options(
+                    max_length=max_length,
+                    num_beams=num_beams,
+                    top_p=top_p,
+                    temperature=temperature,
+                    do_sample=True,
+                    **kwargs,
+                ).to_generation_config(),
+                logits_processor=logit_processor,
+            )
+        if torch.cuda.is_available():
+            outputs = outputs.cpu()
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = self.tokenizer.decode(outputs)
         response = self.model.process_response(response)
diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py
index 5727a524..3c3704ba 100644
--- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py
+++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py
@@ -25,7 +25,7 @@ from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
 if t.TYPE_CHECKING:
     import torch
 
-    from openllm.types import LLMTokenizer
+    from ..._types import LLMTokenizer
 else:
     torch = openllm.utils.LazyLoader("torch", globals(), "torch")
 
@@ -61,6 +61,8 @@ class DollyV2(openllm.LLM, _internal=True):
 
     import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}
 
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     @torch.inference_mode()
     def generate(
         self,
@@ -114,20 +116,23 @@ class DollyV2(openllm.LLM, _internal=True):
         else:
             in_b = input_ids.shape[0]
 
-        generated_sequence = self.model.generate(
-            input_ids=input_ids.to(self.model.device),
-            attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
-            pad_token_id=self.tokenizer.pad_token_id,
-            do_sample=do_sample,
-            eos_token_id=eos_token_id,
-            generation_config=llm_config.to_generation_config(),
-        )
+        with torch.device(self.device):
+            generated_sequence = self.model.generate(
+                input_ids=input_ids.to(self.device),
+                attention_mask=attention_mask.to(self.device) if attention_mask is not None else None,
+                pad_token_id=self.tokenizer.pad_token_id,
+                do_sample=do_sample,
+                eos_token_id=eos_token_id,
+                generation_config=llm_config.to_generation_config(),
+            )
 
         out_b = generated_sequence.shape[0]
-        generated_sequence: list[list[int]] = (
-            generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0].numpy().tolist()
-        )
+        generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])[0]
+        if torch.cuda.is_available():
+            generated_sequence = generated_sequence.cpu()
+
+        generated_sequence: list[list[int]] = generated_sequence.numpy().tolist()
 
         records: list[dict[str, t.Any]] = []
         for sequence in generated_sequence:
            # The response will be set to this variable if we can identify it.
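
Below is a minimal usage sketch of what this patch enables. It assumes the `dolly-v2` start name shown in the `_llm.py` docstring and a lowercase `chatglm` start name derived from `ChatGLMConfig`; the kwarg plumbing is inferred from the hunks above rather than from documented API.

```python
import openllm

# "init_local" is popped inside openllm.Runner (see the _llm.py hunk) and, when set,
# calls runner.init_local() so the model is loaded eagerly -- convenient for development.
dolly_v2_runner = openllm.Runner("dolly-v2", init_local=True)

# ChatGLMConfig now declares trust_remote_code=True, so the flag defaults to True for
# ChatGLM loads; _bentomodel pops "trust_remote_code" from the runner kwargs, so passing
# it explicitly here (hypothetically) overrides the config default.
chatglm_runner = openllm.Runner("chatglm", trust_remote_code=True)
```

Both model classes also pin `device` to CUDA when `torch.cuda.is_available()`, so generation runs on the GPU without any extra flags.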