diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py
index d47d142d..2bb9e60c 100644
--- a/src/openllm/_configuration.py
+++ b/src/openllm/_configuration.py
@@ -135,8 +135,10 @@ class GenerationConfig(pydantic.BaseModel):
         description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
         - `True`, where the generation stops as soon as there are `num_beams` complete candidates;
-        - `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates;
-        - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm)
+        - `False`, where a heuristic is applied and the generation stops when it is very unlikely to find
+          better candidates;
+        - `"never"`, where the beam search procedure only stops when there cannot be better candidates
+          (canonical beam search algorithm)
         """,
     )
     max_time: t.Optional[float] = pydantic.Field(
@@ -154,7 +156,8 @@ class GenerationConfig(pydantic.BaseModel):
     )
     penalty_alpha: t.Optional[float] = pydantic.Field(
         None,
-        description="The values balance the model confidence and the degeneration penalty in contrastive search decoding.",
+        description="""The values balance the model confidence and the degeneration penalty in
+        contrastive search decoding.""",
     )
     use_cache: bool = pydantic.Field(
         True,
@@ -347,12 +350,15 @@ class GenerationConfig(pydantic.BaseModel):
     __openllm_env_name__: str
     __openllm_model_name__: str
 
-    def __init_subclass__(cls, **kwargs: t.Any) -> None:
+    def __init_subclass__(cls, *, _internal: bool = False, **kwargs: t.Any) -> None:
+        if not _internal:
+            raise RuntimeError(
+                "GenerationConfig is not meant to be used directly, "
+                "but you can access this via an LLMConfig.generation_config"
+            )
         model_name = kwargs.get("model_name", None)
         if model_name is None:
-            raise RuntimeError(
-                "GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config"
-            )
+            raise RuntimeError("Failed to initialize GenerationConfig subclass (missing model_name)")
         cls.__openllm_model_name__ = inflection.underscore(model_name)
         cls.__openllm_env_name__ = cls.__openllm_model_name__.upper()
 
@@ -423,8 +429,16 @@ class LLMConfig(pydantic.BaseModel, ABC):
     # The following is handled via __pydantic_init_subclass__, and is only used for TYPE_CHECKING
     __openllm_model_name__: str = ""
     __openllm_start_name__: str = ""
+    __openllm_timeout__: int = 0
     GenerationConfig: type[t.Any] = GenerationConfig
 
+    def __init_subclass__(cls, *, default_timeout: int | None = None, **kwargs: t.Any):
+        if default_timeout is None:
+            default_timeout = 3600
+        cls.__openllm_timeout__ = default_timeout
+
+        super(LLMConfig, cls).__init_subclass__(**kwargs)
+
     @classmethod
     def __pydantic_init_subclass__(cls, **kwargs: t.Any):
         cls.__openllm_model_name__ = inflection.underscore(cls.__name__.replace("Config", ""))
@@ -435,7 +449,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
             types.new_class(
                 cls.__name__.replace("Config", "") + "GenerationConfig",
                 (GenerationConfig,),
-                {"model_name": cls.__openllm_model_name__},
+                {"model_name": cls.__openllm_model_name__, "_internal": True},
             ),
         )
         cls.generation_config = generation_class.construct_from_llm_config(cls)
@@ -448,8 +462,6 @@ class LLMConfig(pydantic.BaseModel, ABC):
                 continue
             field.json_schema_extra["env"] = f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}"
 
-        super().__init_subclass__(**kwargs)
-
     def model_post_init(self, _: t.Any):
         if self.__pydantic_extra__:
             generation_config = self.__pydantic_extra__.pop("generation_config", None)
diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 6de06819..0df5160b 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -70,18 +70,31 @@ class TaskType(enum.Enum, metaclass=TypeMeta):
     TEXT2TEXT_GENERATION = enum.auto()
 
 
-# NOTE: Currently, all LLMs are either text-generation or text2text-generation
-# hence, the two dicts to check are
-# transformers.MODEL_FOR_CAUSAL_LM_MAPPING & transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds: t.Any):
-    """Auto detect model type from given model_name and import it to bentoml's model store."""
-    _framework_impl = kwds.pop("_for_framework", "pt")
+def import_model(model_name: str, tag: bentoml.Tag, __openllm_framework__: str, *model_args: t.Any, **kwds: t.Any):
+    """Auto detect model type from given model_name and import it to bentoml's model store.
+
+    All kwargs are first parsed into `transformers.AutoConfig.from_pretrained`, returning all of the unused kwargs.
+    The unused kwargs are then passed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+    For all tokenizer kwargs, make sure to prefix them with `_tokenizer_` to avoid confusion.
+
+    Note: Currently, only two tasks are supported: `text-generation` and `text2text-generation`.
+
+    Refer to the Transformers documentation for more information about these kwargs.
+
+    Args:
+        model_name: Model name to be imported. Use `openllm models` to see available entries.
+        tag: Tag to be used for the model. This is usually generated for you.
+        model_args: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+        **kwds: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+    """
+    config: transformers.PretrainedConfig = kwds.pop("config", None)
     trust_remote_code = kwds.pop("trust_remote_code", False)
 
     tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
     kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}
+    # the logic below is synonymous with handling `from_pretrained` kwds.
 
     hub_kwds_names = [
         "cache_dir",
@@ -114,7 +127,7 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
     return bentoml.transformers.save_model(
         str(tag),
         getattr(
-            transformers, _return_tensors_to_framework_map[_framework_impl][TaskType[task_type].value - 1]
+            transformers, _return_tensors_to_framework_map[__openllm_framework__][TaskType[task_type].value - 1]
         ).from_pretrained(
             model_name, *model_args, config=config, trust_remote_code=trust_remote_code, **hub_kwds, **kwds
         ),
@@ -129,7 +142,14 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
     )
 
 
-_reserved_namespace = {"default_model", "variants", "config_class", "model", "tokenizer"}
+_reserved_namespace = {
+    "default_model",
+    "variants",
+    "config_class",
+    "model",
+    "tokenizer",
+    "import_kwargs",
+}
 
 
 class LLMInterface(ABC):
@@ -150,6 +170,10 @@ class LLMInterface(ABC):
     config_class: type[openllm.LLMConfig]
     """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
 
+    import_kwargs: dict[str, t.Any] | None = None
+    """The default import kwargs to use when importing the model.
+    This will be passed into the default 'openllm._llm.import_model'."""
+
     @abstractmethod
     def generate(self, prompt: str, **kwargs: t.Any) -> t.Any:
         """The main function implementation for generating from given prompt."""
@@ -157,7 +181,9 @@
 
     def generate_iterator(self, prompt: str, **kwargs: t.Any) -> t.Iterator[t.Any]:
         """An iterator version of generate function."""
-        raise NotImplementedError
+        raise NotImplementedError(
+            "Currently generate_iterator requires SSE (server-sent events) support, which is not yet implemented."
+        )
 
 
 if t.TYPE_CHECKING:
@@ -235,23 +261,41 @@ class LLM(LLMInterface):
         If you need to overwrite the default ``import_model``, implement the following in your subclass:
 
         ```python
-        def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
-            return bentoml.transformers.save_model(str(tag), ...)
+        def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
+            tokenizer_kwargs = {k[len('_tokenizer_'):]: v for k, v in kwargs.items() if k.startswith('_tokenizer_')}
+            kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_tokenizer_')}
+            return bentoml.transformers.save_model(
+                str(tag),
+                transformers.AutoModelForCausalLM.from_pretrained(
+                    pretrained, device_map="auto", torch_dtype=torch.bfloat16, **kwargs
+                ),
+                custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left",
+                                                                                        **tokenizer_kwargs)},
+            )
         ```
 
-        Note: See ``openllm.DollyV2`` for example
+        If your import model doesn't require customization, you can simply pass in `import_kwargs` at the class level; it will then be passed into
+        the default `import_model` implementation. See ``openllm.DollyV2`` for an example.
+
+        ```python
+        dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map='auto')
+        ```
+
+        Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.
 
         Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds.
         passed from the __init__ constructor.
 
         Args:
-            pretrained: The pretrained model to use. Defaults to None. It will use self.default_model if None.
-            llm_config: The config to use for this LLM. Defaults to None. It will use self.config_class to construct default configuration.
+            pretrained: The pretrained model to use. Defaults to None. It will use 'self.default_model' if None.
+            llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
+                        to construct default configuration.
             *args: The args to be passed to the model.
             **kwargs: The kwargs to be passed to the model.
         """
         if llm_config is not None:
+            logger.debug("Using given 'llm_config=%s' to initialize LLM", llm_config)
             self.config = llm_config
         else:
             self.config = self.config_class(**kwargs)
@@ -267,6 +311,7 @@ class LLM(LLMInterface):
             pretrained = self.default_model
         self._pretrained = pretrained
+        # NOTE: Save the args and kwargs for later load
         self._args = args
         self._kwargs = kwargs
@@ -285,9 +330,12 @@ class LLM(LLMInterface):
             logger.debug("Using custom 'import_model' defined in subclass.")
             self.__bentomodel__ = self.import_model(self._pretrained, tag, *self._args, **kwargs)
         else:
-            self._kwargs["_for_framework"] = self._implementation
-            # In this branch, we just use the default implementation.
-            self.__bentomodel__ = import_model(self._pretrained, tag, *self._args, **kwargs)
+            if self.import_kwargs:
+                kwargs = {**self.import_kwargs, **kwargs}
+            # NOTE: In this branch, we just use the default implementation.
+            self.__bentomodel__ = import_model(
+                self._pretrained, tag, __openllm_framework__=self._implementation, *self._args, **kwargs
+            )
         return self.__bentomodel__
 
     @property
@@ -344,7 +392,7 @@ class LLM(LLMInterface):
         name = f"llm-{self.config.__openllm_start_name__}-runner"
         models = models if models is not None else []
-        # NOTE: The side effect of this is that i will load the imported model during runner creation.
+        # NOTE: The side effect of this is that it will load the imported model during runner creation.
         models.append(self._bentomodel)
 
         if scheduling_strategy is None:
@@ -352,25 +400,38 @@
 
             scheduling_strategy = DefaultStrategy
 
-        signature = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
+        generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
+        generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True))
         if method_configs is None:
-            method_configs = {"generate": signature}
+            method_configs = {"generate": generate_sig, "generate_iterator": generate_iterator_sig}
         else:
-            signature = ModelSignature.convert_signatures_dict(method_configs).get("generate", signature)
+            generate_sig = ModelSignature.convert_signatures_dict(method_configs).get("generate", generate_sig)
+            generate_iterator_sig = ModelSignature.convert_signatures_dict(method_configs).get(
+                "generate_iterator", generate_iterator_sig
+            )
 
         class _Runnable(bentoml.Runnable):
             SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
             SUPPORTS_CPU_MULTI_THREADING = True
 
             @bentoml.Runnable.method(
-                batchable=signature.batchable,
-                batch_dim=signature.batch_dim,
-                input_spec=signature.input_spec,
-                output_spec=signature.output_spec,
+                batchable=generate_sig.batchable,
+                batch_dim=generate_sig.batch_dim,
+                input_spec=generate_sig.input_spec,
+                output_spec=generate_sig.output_spec,
             )
             def generate(__self, prompt: str, **kwds: t.Any) -> list[str]:
                 return self.generate(prompt, **kwds)
 
+            @bentoml.Runnable.method(
+                batchable=generate_iterator_sig.batchable,
+                batch_dim=generate_iterator_sig.batch_dim,
+                input_spec=generate_iterator_sig.input_spec,
+                output_spec=generate_iterator_sig.output_spec,
+            )
+            def generate_iterator(__self, prompt: str, **kwds: t.Any) -> t.Iterator[str]:
+                return self.generate_iterator(prompt, **kwds)
+
         return bentoml.Runner(
             t.cast(
                 "type[LLMRunnable]",
diff --git a/src/openllm/cli.py b/src/openllm/cli.py
index f54a3b57..3cada2e9 100644
--- a/src/openllm/cli.py
+++ b/src/openllm/cli.py
@@ -18,6 +18,7 @@ This extends clidantic and BentoML's internal CLI CommandGroup.
 """
 from __future__ import annotations
 
+import copy
 import difflib
 import functools
 import inspect
@@ -270,7 +271,8 @@ def start_model_command(
     @config.to_click_options
     @parse_serve_args(_serve_grpc)
    @OpenLLMCommandGroup.common_chain
-    def model_start(**attrs: t.Any):
+    @click.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
+    def model_start(server_timeout: int, **attrs: t.Any):
         from bentoml._internal.configuration import get_debug_mode
         from bentoml._internal.log import configure_logging
@@ -286,12 +288,22 @@ def start_model_command(
         server_kwds.setdefault("production", not development)
 
         start_env = os.environ.copy()
+
+        # NOTE: This is a hack to set the current configuration
+        _bentoml_config_options = start_env.pop("BENTOML_CONFIG_OPTIONS", "")
+        _bentoml_config_options += " " if _bentoml_config_options else ""
+        _bentoml_config_options += (
+            f"api_server.timeout={server_timeout}"
+            + f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
+        )
+
         start_env.update(
             {
                 openllm.utils.FRAMEWORK_ENV_VAR(model_name): openllm.utils.get_framework_env(model_name),
                 openllm.utils.MODEL_CONFIG_ENV_VAR(model_name): nw_config.model_dump_json(),
                 "OPENLLM_MODEL": model_name,
                 "BENTOML_DEBUG": str(get_debug_mode()),
+                "BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
             }
         )
diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py
index 8b1e2fec..0c041181 100644
--- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py
+++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py
@@ -16,15 +16,10 @@
 
 from __future__ import annotations
 
-import typing as t
-
 import openllm
 
-if t.TYPE_CHECKING:
-    from openllm.types import LLMTokenizer
-
 
-class DollyV2Config(openllm.LLMConfig):
+class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
     """Configuration for the dolly-v2 model."""
 
     return_full_text: bool = False
@@ -72,25 +67,3 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
     instruction="{instruction}",
     response_key=RESPONSE_KEY,
 )
-
-
-def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
-    """
-    Gets the token ID for a given string that has been added to the tokenizer as a special token.
-    When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
-    treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
-
-    Args:
-        tokenizer: the tokenizer
-        key: the key to convert to a single token
-
-    Raises:
-        RuntimeError: if more than one ID was generated
-
-    Returns:
-        int: the token ID for the given key
-    """
-    token_ids = tokenizer.encode(key)
-    if len(token_ids) > 1:
-        raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
-    return token_ids[0]
diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py
index 29dd8533..5727a524 100644
--- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py
+++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py
@@ -17,36 +17,49 @@
 import logging
 import re
 import typing as t
 
-import bentoml
-
 import openllm
 
 from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
-                                     RESPONSE_KEY, get_special_token_id)
+                                     RESPONSE_KEY)
 
 if t.TYPE_CHECKING:
     import torch
-    import transformers
+
+    from openllm.types import LLMTokenizer
 else:
     torch = openllm.utils.LazyLoader("torch", globals(), "torch")
-    transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
 
 logger = logging.getLogger(__name__)
 
 
+def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
+    """
+    Gets the token ID for a given string that has been added to the tokenizer as a special token.
+    When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
+    treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
+
+    Args:
+        tokenizer: the tokenizer
+        key: the key to convert to a single token
+
+    Raises:
+        ValueError: if more than one ID was generated
+
+    Returns:
+        int: the token ID for the given key
+    """
+    token_ids = tokenizer.encode(key)
+    if len(token_ids) > 1:
+        raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
+    return token_ids[0]
+
+
 class DollyV2(openllm.LLM, _internal=True):
     default_model = "databricks/dolly-v2-3b"
 
     variants = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]
 
-    def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
-        return bentoml.transformers.save_model(
-            str(tag),
-            transformers.AutoModelForCausalLM.from_pretrained(
-                pretrained, device_map="auto", torch_dtype=torch.bfloat16
-            ),
-            custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left")},
-        )
+    import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}
 
     @torch.inference_mode()
     def generate(
@@ -58,7 +71,6 @@ class DollyV2(openllm.LLM, _internal=True):
         top_k: float | None = None,
         top_p: float | None = None,
         max_new_tokens: int | None = None,
-        eos_token_id: int | None = None,
         **kwargs: t.Any,
     ):
         """This is a implementation of InstructionTextGenerationPipeline from databricks."""
@@ -67,6 +79,7 @@ class DollyV2(openllm.LLM, _internal=True):
         )
         response_key_token_id = None
         end_key_token_id = None
+        eos_token_id = None
         llm_config = self.config.with_options(
             max_length=max_length,