feat(timeout): support server_timeout and LLM timeout
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -135,8 +135,10 @@ class GenerationConfig(pydantic.BaseModel):
         description="""Controls the stopping condition for beam-based methods, like beam-search. It accepts the
     following values:
     - `True`, where the generation stops as soon as there are `num_beams` complete candidates;
-    - `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates;
-    - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm)
+    - `False`, where an heuristic is applied and the generation stops when is it very unlikely to find
+    better candidates;
+    - `"never"`, where the beam search procedure only stops when there cannot be better candidates
+    (canonical beam search algorithm)
     """,
     )
     max_time: t.Optional[float] = pydantic.Field(
@@ -154,7 +156,8 @@ class GenerationConfig(pydantic.BaseModel):
     )
     penalty_alpha: t.Optional[float] = pydantic.Field(
         None,
-        description="The values balance the model confidence and the degeneration penalty in contrastive search decoding.",
+        description="""The values balance the model confidence and the degeneration penalty in
+        contrastive search decoding.""",
     )
     use_cache: bool = pydantic.Field(
         True,
@@ -347,12 +350,15 @@ class GenerationConfig(pydantic.BaseModel):
     __openllm_env_name__: str
     __openllm_model_name__: str

-    def __init_subclass__(cls, **kwargs: t.Any) -> None:
+    def __init_subclass__(cls, *, _internal: bool = False, **kwargs: t.Any) -> None:
+        if not _internal:
+            raise RuntimeError(
+                "GenerationConfig is not meant to be used directly, "
+                "but you can access this via a LLMConfig.generation_config"
+            )
         model_name = kwargs.get("model_name", None)
         if model_name is None:
-            raise RuntimeError(
-                "GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config"
-            )
+            raise RuntimeError("Failed to initialize GenerationConfig subclass (missing model_name)")
        cls.__openllm_model_name__ = inflection.underscore(model_name)
        cls.__openllm_env_name__ = cls.__openllm_model_name__.upper()

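The `_internal` guard above means the generated GenerationConfig subclasses are only reachable through an LLMConfig. A minimal sketch of the intended access pattern, assuming `DollyV2Config` (defined later in this diff) is exported at the package top level:

```python
# Sketch only: assumes DollyV2Config is importable from the top-level openllm package.
import openllm

config = openllm.DollyV2Config()

# Per __pydantic_init_subclass__ below, each LLMConfig gets a generated GenerationConfig subclass,
# e.g. "DollyV2GenerationConfig", exposed as `generation_config`:
print(type(config.generation_config).__name__)

# Subclassing GenerationConfig directly now raises, because `_internal` defaults to False:
#     class MyGenerationConfig(GenerationConfig): ...   # RuntimeError at class creation
```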
@@ -423,8 +429,16 @@ class LLMConfig(pydantic.BaseModel, ABC):
     # The following is handled via __pydantic_init_subclass__, and is only used for TYPE_CHECKING
     __openllm_model_name__: str = ""
     __openllm_start_name__: str = ""
+    __openllm_timeout__: int = 0
     GenerationConfig: type[t.Any] = GenerationConfig

+    def __init_subclass__(cls, *, default_timeout: int | None = None, **kwargs: t.Any):
+        if default_timeout is None:
+            default_timeout = 3600
+        cls.__openllm_timeout__ = default_timeout
+
+        super(LLMConfig, cls).__init_subclass__(**kwargs)
+
     @classmethod
     def __pydantic_init_subclass__(cls, **kwargs: t.Any):
         cls.__openllm_model_name__ = inflection.underscore(cls.__name__.replace("Config", ""))
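The new `default_timeout` class keyword feeds `__openllm_timeout__`, which the CLI later injects as the runner timeout. A minimal sketch of how a config subclass would use it; `MyConfig` and `DefaultConfig` are purely hypothetical:

```python
# Hypothetical subclasses to illustrate the new class keyword; not part of the codebase.
import openllm


class MyConfig(openllm.LLMConfig, default_timeout=7200):
    """Runner timeout for this model becomes 7200 seconds."""


class DefaultConfig(openllm.LLMConfig):
    """No keyword given, so __init_subclass__ falls back to 3600 seconds."""


assert MyConfig.__openllm_timeout__ == 7200
assert DefaultConfig.__openllm_timeout__ == 3600
```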
@@ -435,7 +449,7 @@ class LLMConfig(pydantic.BaseModel, ABC):
             types.new_class(
                 cls.__name__.replace("Config", "") + "GenerationConfig",
                 (GenerationConfig,),
-                {"model_name": cls.__openllm_model_name__},
+                {"model_name": cls.__openllm_model_name__, "_internal": True},
             ),
         )
         cls.generation_config = generation_class.construct_from_llm_config(cls)
@@ -448,8 +462,6 @@ class LLMConfig(pydantic.BaseModel, ABC):
                 continue
             field.json_schema_extra["env"] = f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}"

-        super().__init_subclass__(**kwargs)
-
     def model_post_init(self, _: t.Any):
         if self.__pydantic_extra__:
             generation_config = self.__pydantic_extra__.pop("generation_config", None)
@@ -70,18 +70,31 @@ class TaskType(enum.Enum, metaclass=TypeMeta):
     TEXT2TEXT_GENERATION = enum.auto()


 # NOTE: Currently, all LLMs are either text-generation or text2text-generation
 # hence, the two dicts to check are
 # transformers.MODEL_FOR_CAUSAL_LM_MAPPING & transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds: t.Any):
-    """Auto detect model type from given model_name and import it to bentoml's model store."""
-    _framework_impl = kwds.pop("_for_framework", "pt")
+def import_model(model_name: str, tag: bentoml.Tag, __openllm_framework__: str, *model_args: t.Any, **kwds: t.Any):
+    """Auto detect model type from given model_name and import it to bentoml's model store.
+
+    For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first, returning all of the unused kwargs.
+    The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+    For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
+
+    Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
+
+    Refer to Transformers documentation for more information about kwargs.
+
+    Args:
+        model_name: Model name to be imported. use `openllm models` to see available entries
+        tag: Tag to be used for the model. This is usually generated for you.
+        model_args: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+        **kwds: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
+    """

     config: transformers.PretrainedConfig = kwds.pop("config", None)
     trust_remote_code = kwds.pop("trust_remote_code", False)

     tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}

     kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}

     # this logic below is synonymous to handling `from_pretrained` kwds.
     hub_kwds_names = [
         "cache_dir",
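The `_tokenizer_` prefix convention documented above splits one flat kwargs dict into model kwargs and tokenizer kwargs. A small illustration of that routing, using the same comprehensions as the hunk (the keyword values are examples only, not defaults from the codebase):

```python
# Example values only; the two comprehensions are copied from import_model above.
kwds = {
    "torch_dtype": "bfloat16",          # forwarded to AutoModelFor*.from_pretrained
    "device_map": "auto",               # forwarded to AutoModelFor*.from_pretrained
    "_tokenizer_padding_side": "left",  # prefix stripped, forwarded to AutoTokenizer.from_pretrained
}

tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}

assert tokenizer_kwds == {"padding_side": "left"}
assert kwds == {"torch_dtype": "bfloat16", "device_map": "auto"}
```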
@@ -114,7 +127,7 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
     return bentoml.transformers.save_model(
         str(tag),
         getattr(
-            transformers, _return_tensors_to_framework_map[_framework_impl][TaskType[task_type].value - 1]
+            transformers, _return_tensors_to_framework_map[__openllm_framework__][TaskType[task_type].value - 1]
         ).from_pretrained(
             model_name, *model_args, config=config, trust_remote_code=trust_remote_code, **hub_kwds, **kwds
         ),
@@ -129,7 +142,14 @@ def import_model(model_name: str, tag: bentoml.Tag, *model_args: t.Any, **kwds:
     )


-_reserved_namespace = {"default_model", "variants", "config_class", "model", "tokenizer"}
+_reserved_namespace = {
+    "default_model",
+    "variants",
+    "config_class",
+    "model",
+    "tokenizer",
+    "import_kwargs",
+}


 class LLMInterface(ABC):
@@ -150,6 +170,10 @@ class LLMInterface(ABC):
     config_class: type[openllm.LLMConfig]
     """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""

+    import_kwargs: dict[str, t.Any] | None = None
+    """The default import kwargs to used when importing the model.
+    This will be passed into the default 'openllm._llm.import_model'."""
+
     @abstractmethod
     def generate(self, prompt: str, **kwargs: t.Any) -> t.Any:
         """The main function implementation for generating from given prompt."""
@@ -157,7 +181,9 @@ class LLMInterface(ABC):

     def generate_iterator(self, prompt: str, **kwargs: t.Any) -> t.Iterator[t.Any]:
         """An iterator version of generate function."""
-        raise NotImplementedError
+        raise NotImplementedError(
+            "Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented."
+        )


 if t.TYPE_CHECKING:
@@ -235,23 +261,41 @@ class LLM(LLMInterface):
         If you need to overwrite the default ``import_model``, implement the following in your subclass:

         ```python
-        def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
-            return bentoml.transformers.save_model(str(tag), ...)
+        def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
+            tokenizer_kwargs = {k[len('_tokenizer_'):]: v for k, v in kwargs.items() if k.startswith('_tokenizer_')]}
+            kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_tokenizer_')}
+            return bentoml.transformers.save_model(
+                str(tag),
+                transformers.AutoModelForCausalLM.from_pretrained(
+                    pretrained, device_map="auto", torch_dtype=torch.bfloat16, **kwargs
+                ),
+                custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left",
+                                                                                        **tokenizer_kwargs)},
+            )
         ```

-        Note: See ``openllm.DollyV2`` for example
+        If your import model doesn't require customization, you can simply pass in `import_kwargs` at class level that will be then passed into
+        The default `import_model` implementation. See ``openllm.DollyV2`` for example.
+
+        ```python
+        dolly_v2_runner = openllm.Runner("dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat8, device_map='gpu')
+        ```
+
+        Note: If you implement your own `import_model`, then `import_kwargs` will be ignored.

         Note that this tag will be generated based on `self.default_model` or the given `pretrained` kwds.
         passed from the __init__ constructor.

         Args:
-            pretrained: The pretrained model to use. Defaults to None. It will use self.default_model if None.
-            llm_config: The config to use for this LLM. Defaults to None. It will use self.config_class to construct default configuration.
+            pretrained: The pretrained model to use. Defaults to None. It will use 'self.default_model' if None.
+            llm_config: The config to use for this LLM. Defaults to None. If not passed, we will use 'self.config_class'
+                to construct default configuration.
             *args: The args to be passed to the model.
             **kwargs: The kwargs to be passed to the model.
         """

         if llm_config is not None:
             logger.debug("Using given 'llm_config=%s' to initialize LLM", llm_config)
             self.config = llm_config
         else:
             self.config = self.config_class(**kwargs)
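The reworked Args section above describes two ways to supply a config. A rough sketch of both paths, assuming `DollyV2` and `DollyV2Config` are exported at the package top level and that generation fields can be passed as keyword overrides:

```python
# Sketch of the two configuration paths; the model classes and kwargs are assumptions, not verified API.
import openllm

# 1) Pass an explicit LLMConfig instance via `llm_config`:
cfg = openllm.DollyV2Config(max_new_tokens=200)
llm = openllm.DollyV2(llm_config=cfg)

# 2) Or pass nothing and let the LLM build its config via self.config_class(**kwargs):
llm = openllm.DollyV2(max_new_tokens=200)
```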
@@ -267,6 +311,7 @@ class LLM(LLMInterface):
             pretrained = self.default_model

         self._pretrained = pretrained
+
         # NOTE: Save the args and kwargs for latter load
         self._args = args
         self._kwargs = kwargs
@@ -285,9 +330,12 @@ class LLM(LLMInterface):
             logger.debug("Using custom 'import_model' defined in subclass.")
             self.__bentomodel__ = self.import_model(self._pretrained, tag, *self._args, **kwargs)
         else:
-            self._kwargs["_for_framework"] = self._implementation
-            # In this branch, we just use the default implementation.
-            self.__bentomodel__ = import_model(self._pretrained, tag, *self._args, **kwargs)
+            if self.import_kwargs:
+                kwargs = {**self.import_kwargs, **kwargs}
+            # NOTE: In this branch, we just use the default implementation.
+            self.__bentomodel__ = import_model(
+                self._pretrained, tag, __openllm_framework__=self._implementation, *self._args, **kwargs
+            )
         return self.__bentomodel__

     @property
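The merge above gives call-site kwargs precedence over the class-level `import_kwargs`; a tiny self-contained illustration of `{**self.import_kwargs, **kwargs}`:

```python
# Plain-dict illustration of the merge: later keys win.
import_kwargs = {"device_map": "auto", "torch_dtype": "bfloat16"}  # e.g. class-level defaults
kwargs = {"torch_dtype": "float16"}                                # supplied by the caller

merged = {**import_kwargs, **kwargs}
assert merged == {"device_map": "auto", "torch_dtype": "float16"}
```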
@@ -344,7 +392,7 @@ class LLM(LLMInterface):
         name = f"llm-{self.config.__openllm_start_name__}-runner"
         models = models if models is not None else []

-        # NOTE: The side effect of this is that i will load the imported model during runner creation.
+        # NOTE: The side effect of this is that will load the imported model during runner creation.
         models.append(self._bentomodel)

         if scheduling_strategy is None:
@@ -352,25 +400,36 @@ class LLM(LLMInterface):

             scheduling_strategy = DefaultStrategy

-        signature = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
+        generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
+        generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True))
         if method_configs is None:
-            method_configs = {"generate": signature}
+            method_configs = {"generate": generate_sig, "generate_iterator": generate_iterator_sig}
         else:
-            signature = ModelSignature.convert_signatures_dict(method_configs).get("generate", signature)
+            generate_sig = ModelSignature.convert_signatures_dict(method_configs).get("generate", generate_sig)
+            ModelSignature.convert_signatures_dict(method_configs).get("generate_iterator", generate_iterator_sig)

         class _Runnable(bentoml.Runnable):
             SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
             SUPPORTS_CPU_MULTI_THREADING = True

             @bentoml.Runnable.method(
-                batchable=signature.batchable,
-                batch_dim=signature.batch_dim,
-                input_spec=signature.input_spec,
-                output_spec=signature.output_spec,
+                batchable=generate_sig.batchable,
+                batch_dim=generate_sig.batch_dim,
+                input_spec=generate_sig.input_spec,
+                output_spec=generate_sig.output_spec,
             )
             def generate(__self, prompt: str, **kwds: t.Any) -> list[str]:
                 return self.generate(prompt, **kwds)

+            @bentoml.Runnable.method(
+                batchable=generate_iterator_sig.batchable,
+                batch_dim=generate_iterator_sig.batch_dim,
+                input_spec=generate_iterator_sig.input_spec,
+                output_spec=generate_iterator_sig.output_spec,
+            )
+            def generate_iterator(__self, prompt: str, **kwds: t.Any) -> t.Iterator[str]:
+                return self.generate_iterator(prompt, **kwds)
+
         return bentoml.Runner(
             t.cast(
                 "type[LLMRunnable]",
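With the split into `generate_sig` and `generate_iterator_sig`, callers can tune each runner method separately through `method_configs`. A sketch of what that could look like, assuming `openllm.Runner` (used in the docstring example earlier) forwards `method_configs` to this runner-creation path and that plain dicts are coerced by `ModelSignature.convert_signatures_dict`:

```python
# Sketch only: forwarding of method_configs through openllm.Runner is an assumption.
import openllm

runner = openllm.Runner(
    "dolly-v2",
    method_configs={
        "generate": {"batchable": False},
        "generate_iterator": {"batchable": True, "batch_dim": 0},
    },
)
```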
@@ -18,6 +18,7 @@ This extends clidantic and BentoML's internal CLI CommandGroup.
 """
 from __future__ import annotations

+import copy
 import difflib
 import functools
 import inspect
@@ -270,7 +271,8 @@ def start_model_command(
     @config.to_click_options
     @parse_serve_args(_serve_grpc)
     @OpenLLMCommandGroup.common_chain
-    def model_start(**attrs: t.Any):
+    @click.option("--server-timeout", type=int, default=3600, help="Server timeout in seconds")
+    def model_start(server_timeout: int, **attrs: t.Any):
         from bentoml._internal.configuration import get_debug_mode
         from bentoml._internal.log import configure_logging

@@ -286,12 +288,24 @@ def start_model_command(
         server_kwds.setdefault("production", not development)

         start_env = os.environ.copy()
+
+        # NOTE: This is a hack to set current configuration
+        _bentoml_config_options = start_env.pop("BENTOML_CONFIG_OPTIONS", "")
+        _bentoml_config_options += (
+            " "
+            if _bentoml_config_options
+            else ""
+            + f"api_server.timeout={server_timeout}"
+            + f' runners."llm-{config.__openllm_start_name__}-runner".timeout={config.__openllm_timeout__}'
+        )
+
         start_env.update(
             {
                 openllm.utils.FRAMEWORK_ENV_VAR(model_name): openllm.utils.get_framework_env(model_name),
                 openllm.utils.MODEL_CONFIG_ENV_VAR(model_name): nw_config.model_dump_json(),
                 "OPENLLM_MODEL": model_name,
                 "BENTOML_DEBUG": str(get_debug_mode()),
+                "BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
             }
         )

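To make the injected configuration concrete: with the default `--server-timeout` and the Dolly V2 timeout set later in this diff, the intent is for `BENTOML_CONFIG_OPTIONS` to carry roughly the string built below (the start name is assumed to dasherize to `dolly-v2`):

```python
# Illustration of the intended BENTOML_CONFIG_OPTIONS payload; timeout values come from this diff,
# the start name is an assumption.
server_timeout = 3600        # --server-timeout default added above
runner_timeout = 3600000     # DollyV2Config(default_timeout=3600000), later in this diff
start_name = "dolly-v2"      # assumed value of config.__openllm_start_name__

options = (
    f"api_server.timeout={server_timeout}"
    f' runners."llm-{start_name}-runner".timeout={runner_timeout}'
)
print(options)
# api_server.timeout=3600 runners."llm-dolly-v2-runner".timeout=3600000
```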
@@ -16,15 +16,10 @@

 from __future__ import annotations

-import typing as t
-
 import openllm

-if t.TYPE_CHECKING:
-    from openllm.types import LLMTokenizer
-

-class DollyV2Config(openllm.LLMConfig):
+class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
     """Configuration for the dolly-v2 model."""

     return_full_text: bool = False
@@ -72,25 +67,3 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
     instruction="{instruction}",
     response_key=RESPONSE_KEY,
 )
-
-
-def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
-    """
-    Gets the token ID for a given string that has been added to the tokenizer as a special token.
-    When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
-    treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
-
-    Args:
-        tokenizer: the tokenizer
-        key: the key to convert to a single token
-
-    Raises:
-        RuntimeError: if more than one ID was generated
-
-    Returns:
-        int: the token ID for the given key
-    """
-    token_ids = tokenizer.encode(key)
-    if len(token_ids) > 1:
-        raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
-    return token_ids[0]
@@ -17,36 +17,49 @@ import logging
 import re
 import typing as t

 import bentoml

 import openllm

 from .configuration_dolly_v2 import (DEFAULT_PROMPT_TEMPLATE, END_KEY,
-                                     RESPONSE_KEY, get_special_token_id)
+                                     RESPONSE_KEY)

 if t.TYPE_CHECKING:
     import torch
     import transformers

+    from openllm.types import LLMTokenizer
 else:
     torch = openllm.utils.LazyLoader("torch", globals(), "torch")
     transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")

 logger = logging.getLogger(__name__)


+def get_special_token_id(tokenizer: LLMTokenizer, key: str) -> int:
+    """
+    Gets the token ID for a given string that has been added to the tokenizer as a special token.
+    When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
+    treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
+
+    Args:
+        tokenizer: the tokenizer
+        key: the key to convert to a single token
+
+    Raises:
+        RuntimeError: if more than one ID was generated
+
+    Returns:
+        int: the token ID for the given key
+    """
+    token_ids = tokenizer.encode(key)
+    if len(token_ids) > 1:
+        raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
+    return token_ids[0]
+
+
 class DollyV2(openllm.LLM, _internal=True):
     default_model = "databricks/dolly-v2-3b"

     variants = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]

-    def import_model(self, pretrained: str, tag: bentoml.Tag, *args: t.Any, **kwargs: t.Any):
-        return bentoml.transformers.save_model(
-            str(tag),
-            transformers.AutoModelForCausalLM.from_pretrained(
-                pretrained, device_map="auto", torch_dtype=torch.bfloat16
-            ),
-            custom_objects={"tokenizer": transformers.AutoTokenizer.from_pretrained(pretrained, padding_size="left")},
-        )
+    import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_size": "left"}

     @torch.inference_mode()
     def generate(
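For context on the relocated helper: it is typically called with the Dolly special markers to resolve single token ids used as stopping criteria. A hedged sketch of that usage (the actual call sites live in `generate`, outside this hunk, and it assumes the module-level imports shown above plus a tokenizer that exposes the markers as single added tokens):

```python
# Sketch of typical usage; RESPONSE_KEY and END_KEY come from configuration_dolly_v2.py,
# and the dolly-v2 tokenizer is assumed to treat them as single special tokens.
tokenizer = transformers.AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")

response_key_token_id = get_special_token_id(tokenizer, RESPONSE_KEY)
end_key_token_id = get_special_token_id(tokenizer, END_KEY)
# get_special_token_id raises ValueError if a key maps to more than one token id.
```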
@@ -58,7 +71,6 @@ class DollyV2(openllm.LLM, _internal=True):
         top_k: float | None = None,
         top_p: float | None = None,
         max_new_tokens: int | None = None,
-        eos_token_id: int | None = None,
         **kwargs: t.Any,
     ):
         """This is a implementation of InstructionTextGenerationPipeline from databricks."""
@@ -67,6 +79,7 @@ class DollyV2(openllm.LLM, _internal=True):
         )
         response_key_token_id = None
         end_key_token_id = None
+        eos_token_id = None

         llm_config = self.config.with_options(
             max_length=max_length,