feat(timeout): support server_timeout and LLM timeout

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron
Date:   2023-05-23 16:48:01 -07:00
Commit: 162c021cae (parent: b1c07946c1)
5 changed files with 149 additions and 78 deletions


@@ -135,8 +135,10 @@ class GenerationConfig(pydantic.BaseModel):
description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the
following values:
- `True`, where the generation stops as soon as there are `num_beams` complete candidates;
- `False`, where a heuristic is applied and the generation stops when it is very unlikely to find better candidates;
- `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm)
- `False`, where a heuristic is applied and the generation stops when it is very unlikely to find
better candidates;
- `"never"`, where the beam search procedure only stops when there cannot be better candidates
(canonical beam search algorithm)
""",
)
max_time: t.Optional[float] = pydantic.Field(
@@ -154,7 +156,8 @@ class GenerationConfig(pydantic.BaseModel):
)
penalty_alpha: t.Optional[float] = pydantic.Field(
None,
description="The values balance the model confidence and the degeneration penalty in contrastive search decoding.",
description="""The values balance the model confidence and the degeneration penalty in
contrastive search decoding.""",
)
use_cache: bool = pydantic.Field(
True,
@@ -347,12 +350,15 @@ class GenerationConfig(pydantic.BaseModel):
__openllm_env_name__: str
__openllm_model_name__: str
def __init_subclass__(cls, **kwargs: t.Any) -> None:
def __init_subclass__(cls, *, _internal: bool = False, **kwargs: t.Any) -> None:
if not _internal:
raise RuntimeError(
"GenerationConfig is not meant to be used directly, "
"but you can access this via a LLMConfig.generation_config"
)
model_name = kwargs.get("model_name", None)
if model_name is None:
raise RuntimeError(
"GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config"
)
raise RuntimeError("Failed to initialize GenerationConfig subclass (missing model_name)")
cls.__openllm_model_name__ = inflection.underscore(model_name)
cls.__openllm_env_name__ = cls.__openllm_model_name__.upper()
@@ -423,8 +429,16 @@ class LLMConfig(pydantic.BaseModel, ABC):
# The following is handled via __pydantic_init_subclass__, and is only used for TYPE_CHECKING
__openllm_model_name__: str = ""
__openllm_start_name__: str = ""
__openllm_timeout__: int = 0
GenerationConfig: type[t.Any] = GenerationConfig
def __init_subclass__(cls, *, default_timeout: int | None = None, **kwargs: t.Any):
if default_timeout is None:
default_timeout = 3600
cls.__openllm_timeout__ = default_timeout
super(LLMConfig, cls).__init_subclass__(**kwargs)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs: t.Any):
cls.__openllm_model_name__ = inflection.underscore(cls.__name__.replace("Config", ""))
@@ -435,7 +449,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
types.new_class(
cls.__name__.replace("Config", "") + "GenerationConfig",
(GenerationConfig,),
{"model_name": cls.__openllm_model_name__},
{"model_name": cls.__openllm_model_name__, "_internal": True},
),
)
cls.generation_config = generation_class.construct_from_llm_config(cls)
@@ -448,8 +462,6 @@ class LLMConfig(pydantic.BaseModel, ABC):
continue
field.json_schema_extra["env"] = f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}"
super().__init_subclass__(**kwargs)
def model_post_init(self, _: t.Any):
if self.__pydantic_extra__:
generation_config = self.__pydantic_extra__.pop("generation_config", None)

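A minimal sketch of how the two new class keywords are meant to work together, assuming `openllm` is importable as in this repo; the `MyLLMConfig` name, its `temperature` field, and the `7200` value are hypothetical:

```python
import openllm


# Hypothetical config subclass: `default_timeout` flows through
# `LLMConfig.__init_subclass__` and is stored on `__openllm_timeout__`,
# which the CLI later forwards as the runner timeout.
class MyLLMConfig(openllm.LLMConfig, default_timeout=7200):
    temperature: float = 0.75


assert MyLLMConfig.__openllm_timeout__ == 7200

# The generated GenerationConfig subclass is only reachable through the config;
# subclassing GenerationConfig without `_internal=True` now raises RuntimeError.
print(MyLLMConfig.generation_config)
```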

@@ -70,18 +70,31 @@ class TaskType(enum.Enum, metaclass=TypeMeta):
TEXT2TEXT_GENERATION = enum.auto()
# NOTE: Currently, all LLMs are either text-generation or text2text-generation
# hence, the two dicts to check are
# transformers.MODEL_FOR_CAUSAL_LM_MAPPING & transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds: t.Any):
"""Auto detect model type from given model_name and import it to bentoml's model store."""
_framework_impl = kwds.pop("_for_framework", "pt")
def import_model(model_name: str, tag: bentoml.Tag, __openllm_framework__: str, *model_args: t.Any, **kwds: t.Any):
"""Auto detect model type from given model_name and import it to bentoml's model store.
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first, returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
Refer to Transformers documentation for more information about kwargs.
Args:
model_name: Model name to be imported. use `openllm models` to see available entries
tag: Tag to be used for the model. This is usually generated for you.
model_args: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
**kwds: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
"""
config: transformers.PretrainedConfig = kwds.pop("config", None)
trust_remote_code = kwds.pop("trust_remote_code", False)
tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}
# this logic below is synonymous to handling `from_pretrained` kwds.
hub_kwds_names = [
"cache_dir",
@@ -114,7 +127,7 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
return bentoml.transformers.save_model(
str(tag),
getattr(
transformers, _return_tensors_to_framework_map[_framework_impl][TaskType[task_type].value - 1]
transformers, _return_tensors_to_framework_map[__openllm_framework__][TaskType[task_type].value - 1]
).from_pretrained(
model_name, *model_args, config=config, trust_remote_code=trust_remote_code, **hub_kwds, **kwds
),
@@ -129,7 +142,14 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
)
_reserved_namespace = {"default_model", "variants", "config_class", "model", "tokenizer"}
_reserved_namespace = {
"default_model",
"variants",
"config_class",
"model",
"tokenizer",
"import_kwargs",
}
class LLMInterface(ABC):
@@ -150,6 +170,10 @@ class LLMInterface(ABC):
config_class: type[openllm.LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
import_kwargs: dict[str, t.Any] | None = None
"""The default import kwargs to used when importing the model.
This will be passed into the default 'openllm._llm.import_model'."""
@abstractmethod
def generate(self, prompt: str, **kwargs: t.Any) -> t.Any:
"""The main function implementation for generating from given prompt."""
@@ -157,7 +181,9 @@ class LLMInterface(ABC):
def generate_iterator(self, prompt: str, **kwargs: t.Any) -> t.Iterator[t.Any]:
"""An iterator version of generate function."""
raise NotImplementedError
raise NotImplementedError(
"Currently, generate_iterator requires SSE (server-sent events) support, which is not yet implemented."
)
if t.TYPE_CHECKING:
@@ -235,23 +261,41 @@ class LLM(LLMInterface):
If you need to overwrite the default ``import_model``, implement the following in your subclass:
```python
def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
return bentoml.transformers.save_model(str(tag), ...)
def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
tokenizer_kwargs = {k[len('_tokenizer_'):]: v for k, v in kwargs.items() if k.startswith('_tokenizer_')}
kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_tokenizer_')}
return bentoml.transformers.save_model(
str(tag),
transformers.AutoModelForCausalLM.from_pretrained(
pretrained, device_map="auto", torch_dtype=torch.bfloat16, **kwargs
),
custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left",
**tokenizer_kwargs)},
)
```
Note: See ``openllm.DollyV2`` for example
If your import_model doesn't require customization, you can simply pass in `import_kwargs` at the class level;
these will then be passed into the default `import_model` implementation. See ``openllm.DollyV2`` for an example.
```python
dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map='gpu')
```
Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.
Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds
passed from the __init__ constructor.
Args:
pretrained: The pretrained model to use. Defaults to None. It will use self.default_model if None.
llm_config: The config to use for this LLM. Defaults to None. It will use self.config_class to construct default configuration.
pretrained: The pretrained model to use. Defaults to None. It will use 'self.default_model' if None.
llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
to construct default configuration.
*args: The args to be passed to the model.
**kwargs: The kwargs to be passed to the model.
"""
if llm_config is not None:
logger.debug("Using given 'llm_config=%s' to initialize LLM", llm_config)
self.config = llm_config
else:
self.config = self.config_class(**kwargs)
@@ -267,6 +311,7 @@ class LLM(LLMInterface):
pretrained = self.default_model
self._pretrained = pretrained
# NOTE: Save the args and kwargs for latter load
self._args = args
self._kwargs = kwargs
@@ -285,9 +330,12 @@ class LLM(LLMInterface):
logger.debug("Using custom 'import_model' defined in subclass.")
self.__bentomodel__ = self.import_model(self._pretrained, tag, *self._args, **kwargs)
else:
self._kwargs["_for_framework"] = self._implementation
# In this branch, we just use the default implementation.
self.__bentomodel__ = import_model(self._pretrained, tag, *self._args, **kwargs)
if self.import_kwargs:
kwargs = {**self.import_kwargs, **kwargs}
# NOTE: In this branch, we just use the default implementation.
self.__bentomodel__ = import_model(
self._pretrained, tag, self._implementation, *self._args, **kwargs
)
return self.__bentomodel__
@property
@@ -344,7 +392,7 @@ class LLM(LLMInterface):
name = f"llm-{self.config.__openllm_start_name__}-runner"
models = models if models is not None else []
# NOTE: The side effect of this is that i will load the imported model during runner creation.
# NOTE: The side effect of this is that it will load the imported model during runner creation.
models.append(self._bentomodel)
if scheduling_strategy is None:
@@ -352,25 +400,36 @@ class LLM(LLMInterface):
scheduling_strategy = DefaultStrategy
signature = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True))
if method_configs is None:
method_configs = {"generate": signature}
method_configs = {"generate": generate_sig, "generate_iterator": generate_iterator_sig}
else:
signature = ModelSignature.convert_signatures_dict(method_configs).get("generate", signature)
generate_sig = ModelSignature.convert_signatures_dict(method_configs).get("generate", generate_sig)
ModelSignature.convert_signatures_dict(method_configs).get("generate_iterator", generate_iterator_sig)
class _Runnable(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
@bentoml.Runnable.method(
batchable=signature.batchable,
batch_dim=signature.batch_dim,
input_spec=signature.input_spec,
output_spec=signature.output_spec,
batchable=generate_sig.batchable,
batch_dim=generate_sig.batch_dim,
input_spec=generate_sig.input_spec,
output_spec=generate_sig.output_spec,
)
def generate(__self, prompt: str, **kwds: t.Any) -> list[str]:
return self.generate(prompt, **kwds)
@bentoml.Runnable.method(
batchable=generate_iterator_sig.batchable,
batch_dim=generate_iterator_sig.batch_dim,
input_spec=generate_iterator_sig.input_spec,
output_spec=generate_iterator_sig.output_spec,
)
def generate_iterator(__self, prompt: str, **kwds: t.Any) -> t.Iterator[str]:
return self.generate_iterator(prompt, **kwds)
return bentoml.Runner(
t.cast(
"type[LLMRunnable]",

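As the docstrings in this file describe, tokenizer kwargs are separated from model kwargs by a `_tokenizer_` prefix. A self-contained sketch of that split; the kwarg names and values here are made up for illustration:

```python
# Split `_tokenizer_`-prefixed kwargs from model kwargs, mirroring the logic
# used by the default import_model above.
kwds = {"device_map": "auto", "_tokenizer_padding_size": "left"}

tokenizer_kwds = {k[len("_tokenizer_"):]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
model_kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}

assert tokenizer_kwds == {"padding_size": "left"}
assert model_kwds == {"device_map": "auto"}
```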

@@ -18,6 +18,7 @@ This extends clidantic and BentoML's internal CLI CommandGroup.
"""
from __future__ import annotations
import copy
import difflib
import functools
import inspect
@@ -270,7 +271,8 @@ def start_model_command(
@config.to_click_options
@parse_serve_args(_serve_grpc)
@OpenLLMCommandGroup.common_chain
def model_start(**attrs: t.Any):
@click.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
def model_start(server_timeout: int, **attrs: t.Any):
from bentoml._internal.configuration import get_debug_mode
from bentoml._internal.log import configure_logging
@@ -286,12 +288,24 @@ def start_model_command(
server_kwds.setdefault("production", not development)
start_env = os.environ.copy()
# NOTE: This is a hack to set current configuration
_bentoml_config_options = start_env.pop("BENTOML_CONFIG_OPTIONS", "")
_bentoml_config_options += (
" "
if _bentoml_config_options
else ""
+ f"api_server.timeout={server_timeout}"
+ f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
)
start_env.update(
{
openllm.utils.FRAMEWORK_ENV_VAR(model_name): openllm.utils.get_framework_env(model_name),
openllm.utils.MODEL_CONFIG_ENV_VAR(model_name): nw_config.model_dump_json(),
"OPENLLM_MODEL": model_name,
"BENTOML_DEBUG": str(get_debug_mode()),
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
}
)

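For reference, the `BENTOML_CONFIG_OPTIONS` value assembled above is a space-separated list of dotted BentoML config overrides. A rough sketch of the resulting string, assuming the defaults that appear in this commit (`--server-timeout 3600`, dolly-v2's `default_timeout=3600000`, and a start name of `dolly-v2`):

```python
# Approximate shape of BENTOML_CONFIG_OPTIONS after start_model_command runs
# (values are the defaults seen in this commit, shown for illustration).
server_timeout = 3600        # from the new --server-timeout CLI option
runner_timeout = 3600000     # from the config class's __openllm_timeout__
start_name = "dolly-v2"

options = (
    f"api_server.timeout={server_timeout}"
    f' runners."llm-{start_name}-runner".timeout={runner_timeout}'
)
print(options)
# api_server.timeout=3600 runners."llm-dolly-v2-runner".timeout=3600000
```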

@@ -16,15 +16,10 @@
from __future__ import annotations
import typing as t
import openllm
if t.TYPE_CHECKING:
from openllm.types import LLMTokenizer
class DollyV2Config(openllm.LLMConfig):
class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
"""Configuration for the dolly-v2 model."""
return_full_text: bool = False
@@ -72,25 +67,3 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
instruction="{instruction}",
response_key=RESPONSE_KEY,
)
def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
"""
Gets the token ID for a given string that has been added to the tokenizer as a special token.
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
Args:
tokenizer: the tokenizer
key: the key to convert to a single token
Raises:
RuntimeError: if more than one ID was generated
Returns:
int: the token ID for the given key
"""
token_ids = tokenizer.encode(key)
if len(token_ids) > 1:
raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
return token_ids[0]

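The `get_special_token_id` helper removed here reappears in the modeling file in the next diff. A hedged, standalone sketch of the single-token check it performs; the `gpt2` tokenizer is a stand-in, not what Dolly actually ships with:

```python
from transformers import AutoTokenizer

# Register a key as a single special token, then look up its lone id,
# the same invariant get_special_token_id enforces.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["### Response:"]})

token_ids = tokenizer.encode("### Response:")
if len(token_ids) > 1:
    raise ValueError(f"Expected only a single token for '### Response:' but found {token_ids}")
print(token_ids[0])
```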

@@ -17,36 +17,49 @@ import logging
import re
import typing as t
import bentoml
import openllm
from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
RESPONSE_KEY, get_special_token_id)
RESPONSE_KEY)
if t.TYPE_CHECKING:
import torch
import transformers
from openllm.types import LLMTokenizer
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
"""
Gets the token ID for a given string that has been added to the tokenizer as a special token.
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
treated specially and converted to a single, new token. This retrieves the token ID each of these keys maps to.
Args:
tokenizer: the tokenizer
key: the key to convert to a single token
Raises:
ValueError: if more than one ID was generated
Returns:
int: the token ID for the given key
"""
token_ids = tokenizer.encode(key)
if len(token_ids) > 1:
raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
return token_ids[0]
class DollyV2(openllm.LLM, _internal=True):
default_model = "databricks/dolly-v2-3b"
variants = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]
def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
return bentoml.transformers.save_model(
str(tag),
transformers.AutoModelForCausalLM.from_pretrained(
pretrained, device_map="auto", torch_dtype=torch.bfloat16
),
custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left")},
)
import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}
@torch.inference_mode()
def generate(
@@ -58,7 +71,6 @@ class DollyV2(openllm.LLM, _internal=True):
top_k: float | None = None,
top_p: float | None = None,
max_new_tokens: int | None = None,
eos_token_id: int | None = None,
**kwargs: t.Any,
):
"""This is a implementation of InstructionTextGenerationPipeline from databricks."""
@@ -67,6 +79,7 @@ class DollyV2(openllm.LLM, _internal=True):
)
response_key_token_id = None
end_key_token_id = None
eos_token_id = None
llm_config = self.config.with_options(
max_length=max_length,
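With `import_kwargs` now declared on `DollyV2`, the custom `import_model` override above is dropped in favor of the default importer. A hedged usage sketch, mirroring the docstring example earlier in this commit; the extra kwarg is illustrative:

```python
import openllm

# The default import_model merges DollyV2.import_kwargs with anything passed here;
# `_tokenizer_*` entries go to the tokenizer, the rest to AutoModelForCausalLM.
dolly_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left")
```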