chore(inference): update vllm to 0.2.1.post1 and update config parsing (#554)

chore(dependencies): update vllm to 0.2.1.post1 and update config
parsing
This commit is contained in:
Aaron Pham
2023-11-04 04:01:56 -04:00
committed by GitHub
parent 440e3d646f
commit 72c6005d3b
3 changed files with 23 additions and 3 deletions

View File

@@ -390,6 +390,8 @@ class SamplingParams(ReprMixin):
stop: t.List[str] = dantic.Field(None, description='List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.')
ignore_eos: bool = dantic.Field(False, description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.')
logprobs: int = dantic.Field(None, description='Number of log probabilities to return per output token.')
prompt_logprobs: t.Optional[int] = dantic.Field(None, description='Number of log probabilities to return per input token.')
skip_special_tokens: bool = dantic.Field(True, description='Whether to skip special tokens in the generated output.')
if t.TYPE_CHECKING:
max_tokens: int
@@ -407,6 +409,9 @@ class SamplingParams(ReprMixin):
_object_setattr(self, 'temperature', attrs.pop('temperature', 1.0))
_object_setattr(self, 'top_k', attrs.pop('top_k', -1))
_object_setattr(self, 'top_p', attrs.pop('top_p', 1.0))
_object_setattr(self, 'repetition_penalty', attrs.pop('repetition_penalty', 1.0))
_object_setattr(self, 'length_penalty', attrs.pop('length_penalty', 1.0))
_object_setattr(self, 'early_stopping', attrs.pop('early_stopping', False))
self.__attrs_init__(**attrs)
def __getitem__(self, item: str) -> t.Any:
@@ -432,7 +437,18 @@ class SamplingParams(ReprMixin):
top_k = first_not_none(attrs.pop('top_k', None), default=generation_config['top_k'])
top_p = first_not_none(attrs.pop('top_p', None), default=generation_config['top_p'])
max_tokens = first_not_none(attrs.pop('max_tokens', None), attrs.pop('max_new_tokens', None), default=generation_config['max_new_tokens'])
return cls(_internal=True, temperature=temperature, top_k=top_k, top_p=top_p, max_tokens=max_tokens, **attrs)
repetition_penalty = first_not_none(attrs.pop('repetition_penalty', None), default=generation_config['repetition_penalty'])
length_penalty = first_not_none(attrs.pop('length_penalty', None), default=generation_config['length_penalty'])
early_stopping = first_not_none(attrs.pop('early_stopping', None), default=generation_config['early_stopping'])
return cls(_internal=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
early_stopping=early_stopping,
**attrs)
bentoml_cattr.register_unstructure_hook_factory(
lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), lambda cls: make_dict_unstructure_fn(
@@ -1234,6 +1250,10 @@ class LLMConfig(_ConfigAttr):
def __getitem__(self, item: t.Literal['ignore_eos']) -> bool: ...
@overload
def __getitem__(self, item: t.Literal['logprobs']) -> int: ...
@overload
def __getitem__(self, item: t.Literal['prompt_logprobs']) -> t.Optional[int]: ...
@overload
def __getitem__(self, item: t.Literal['skip_special_tokens']) -> bool: ...
# NOTE: PeftType arguments
@overload
def __getitem__(self, item: t.Literal['prompt_tuning']) -> dict[str, t.Any]: ...

View File

@@ -115,7 +115,7 @@ openai = ["openai[embeddings]", "tiktoken"]
opt = ["flax>=0.7", "jax", "jaxlib", "tensorflow", "keras"]
playground = ["jupyter", "notebook", "ipython", "jupytext", "nbformat"]
starcoder = ["bitsandbytes"]
vllm = ["vllm>=0.2.0", "ray"]
vllm = ["vllm>=0.2.1post1", "ray"]
[tool.hatch.version]
fallback-version = "0.0.0"

View File

@@ -133,7 +133,7 @@ AGENTS_DEPS = ['transformers[agents]>=4.30', 'diffusers', 'soundfile']
PLAYGROUND_DEPS = ['jupyter', 'notebook', 'ipython', 'jupytext', 'nbformat']
GGML_DEPS = ['ctransformers']
GPTQ_DEPS = ['auto-gptq[triton]>=0.4.2', 'optimum>=1.12.0']
VLLM_DEPS = ['vllm>=0.2.0', 'ray']
VLLM_DEPS = ['vllm>=0.2.1post1', 'ray']
_base_requirements: dict[str, t.Any] = {
inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__