feat: heuristics logprobs (#692)

* fix(encoder): bring back T5 support on PyTorch

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: support logprobs and prompt_logprobs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* docs: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron Pham
2023-11-18 19:26:20 -05:00
committed by GitHub
parent 4499469efb
commit 1831d8f129
6 changed files with 135 additions and 54 deletions

View File

@@ -0,0 +1,3 @@
PyTorch runners now support logprobs calculation for the `logits` output.
Updated the logits calculation to support encoder-decoder models (which fixes T5 inference).
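
A minimal usage sketch of the new options, assuming the `openllm.LLM(model_id, backend='pt')` constructor of this release; the `generate` signature and the `GenerationOutput`/`CompletionChunk` fields are the ones touched in the diffs below, while the model id and values are only illustrative:

```python
import asyncio
import openllm

async def main():
    # Hypothetical handle; any PyTorch-backed model is expected to behave the same.
    llm = openllm.LLM('facebook/opt-125m', backend='pt')
    # logprobs / prompt_logprobs are forwarded through **attrs into the generation config.
    out = await llm.generate('What is the meaning of life?', logprobs=2, prompt_logprobs=1)
    chunk = out.outputs[0]
    print(chunk.text)
    print(chunk.cumulative_logprob)  # sum of per-token logprobs
    print(chunk.logprobs)            # [{token_id: logprob}, ...] per generated token
    print(out.prompt_logprobs)       # [{token_id: logprob}, ...] per prompt token

asyncio.run(main())
```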

View File

@@ -195,6 +195,9 @@ class GenerationConfig(ReprMixin):
decoder_start_token_id: int = dantic.Field(
description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.'
)
# NOTE: This is now implemented and supported for both PyTorch and vLLM
logprobs: int = dantic.Field(0, description='Number of log probabilities to return per output token.')
prompt_logprobs: int = dantic.Field(0, description='Number of log probabilities to return per input token.')
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
if not _internal:
@@ -224,7 +227,7 @@ converter.register_unstructure_hook_factory(
),
)
_GenerationConfigT = t.TypeVar('_GenerationConfig', bound=GenerationConfig)
_GenerationConfigT = t.TypeVar('_GenerationConfigT', bound=GenerationConfig)
@attr.frozen(slots=True, repr=False, init=False)
@@ -256,10 +259,6 @@ class SamplingParams(ReprMixin):
False,
description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.',
)
logprobs: int = dantic.Field(None, description='Number of log probabilities to return per output token.')
prompt_logprobs: t.Optional[int] = dantic.Field(
None, description='Number of log probabilities to return per input token.'
)
skip_special_tokens: bool = dantic.Field(True, description='Whether to skip special tokens in the generated output.')
# space_between_special_tokens: bool = dantic.Field(True, description='Whether to add a space between special tokens in the generated output.')
@@ -268,6 +267,7 @@ class SamplingParams(ReprMixin):
temperature: float
top_k: int
top_p: float
logprobs: int
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
if not _internal:
@@ -319,8 +319,6 @@ class SamplingParams(ReprMixin):
)
length_penalty = first_not_none(attrs.pop('length_penalty', None), default=generation_config['length_penalty'])
early_stopping = first_not_none(attrs.pop('early_stopping', None), default=generation_config['early_stopping'])
logprobs = first_not_none(attrs.pop('logprobs', None), default=None)
prompt_logprobs = first_not_none(attrs.pop('prompt_logprobs', None), default=None)
return cls(
_internal=True,
@@ -331,8 +329,6 @@ class SamplingParams(ReprMixin):
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
early_stopping=early_stopping,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
**attrs,
)
@@ -1213,6 +1209,10 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
def __getitem__(self,item:t.Literal['encoder_no_repeat_ngram_size'])->int:...
@overload
def __getitem__(self,item:t.Literal['decoder_start_token_id'])->int:...
@overload
def __getitem__(self,item:t.Literal['logprobs'])->int:...
@overload
def __getitem__(self,item:t.Literal['prompt_logprobs'])->int:...
# NOTE: SamplingParams arguments
@overload
def __getitem__(self,item:t.Literal['n'])->int:...
@@ -1227,10 +1227,6 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
@overload
def __getitem__(self,item:t.Literal['ignore_eos'])->bool:...
@overload
def __getitem__(self,item:t.Literal['logprobs'])->int:...
@overload
def __getitem__(self,item:t.Literal['prompt_logprobs'])->t.Optional[int]:...
@overload
def __getitem__(self,item:t.Literal['skip_special_tokens'])->bool:...
# NOTE: PeftType arguments
@overload
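
A short sketch of how the new fields travel through the config layer. `AutoConfig.for_model` resolving a concrete `LLMConfig` is an assumption; `model_construct_env` and the `config['logprobs']` indexing are the calls the runners below rely on:

```python
from openllm_core.config import AutoConfig

# Resolve a concrete LLMConfig subclass (model name is illustrative).
llm_config = AutoConfig.for_model('opt')

# Merge runtime overrides the same way the runners do.
config = llm_config.model_construct_env(logprobs=2, prompt_logprobs=1, max_new_tokens=64)

print(config['logprobs'])         # 2, typed as int via the __getitem__ overloads above
print(config['prompt_logprobs'])  # 1
```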

View File

@@ -1,5 +1,3 @@
"""Schema definition for OpenLLM. This schema is used throughout openllm core components library."""
from __future__ import annotations
import typing as t
@@ -9,7 +7,7 @@ import orjson
from ._configuration import LLMConfig
from .config import AutoConfig
from .utils import converter, gen_random_uuid
from .utils import ReprMixin, converter, gen_random_uuid
if t.TYPE_CHECKING:
import vllm
@@ -17,8 +15,15 @@ if t.TYPE_CHECKING:
from ._typing_compat import Self
@attr.define
class _SchemaMixin:
@attr.define(repr=False)
class _SchemaMixin(ReprMixin):
@property
def __repr_keys__(self):
return list(attr.fields_dict(self.__class__))
def __repr_args__(self):
yield from ((k, getattr(self, k)) for k in self.__repr_keys__)
def model_dump(self) -> dict[str, t.Any]:
return converter.unstructure(self)
@@ -29,7 +34,7 @@ class _SchemaMixin:
return attr.evolve(self, **options)
@attr.define
@attr.define(repr=False)
class MetadataOutput(_SchemaMixin):
model_id: str
timeout: int
@@ -51,7 +56,7 @@ class MetadataOutput(_SchemaMixin):
}
@attr.define
@attr.define(repr=False)
class GenerationInput(_SchemaMixin):
prompt: str
llm_config: LLMConfig
@@ -111,7 +116,7 @@ PromptLogprobs = t.List[t.Optional[t.Dict[int, float]]]
FinishReason = t.Literal['length', 'stop']
@attr.define
@attr.define(repr=False)
class CompletionChunk(_SchemaMixin):
index: int
text: str
@@ -124,7 +129,7 @@ class CompletionChunk(_SchemaMixin):
return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
@attr.define
@attr.define(repr=False)
class GenerationOutput(_SchemaMixin):
prompt: str
finished: bool
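
For reference, the shapes of the two logprob payloads as the PyTorch runner builds them later in this commit; `PromptLogprobs` is the alias defined in this file, and `SampleLogprobs` is assumed to be the matching per-output-token list of `{token_id: logprob}` dicts (ids and values are illustrative):

```python
# One optional dict per *prompt* token; None when no logprob was recorded for it.
prompt_logprobs = [None, {29871: -3.21}, {1234: -0.88}]  # PromptLogprobs

# One dict per *generated* token, in generation order.
sample_logprobs = [{450: -0.12}, {1234: -1.05}]          # SampleLogprobs
```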

View File

@@ -387,7 +387,7 @@ class LLM(t.Generic[M, T], ReprMixin):
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
) -> GenerationOutput:
config = self.config.model_construct_env(**attrs)
texts, token_ids = [[]] * config['n'], [[]] * config['n']
final_result = None
@@ -410,7 +410,7 @@ class LLM(t.Generic[M, T], ReprMixin):
async def generate_iterator(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
) -> t.AsyncGenerator[GenerationOutput, None]:
from bentoml._internal.runner.runner_handle import DummyRunnerHandle
if isinstance(self.runner._runner_handle, DummyRunnerHandle):
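
A small sketch of consuming the now-typed streaming API, under the same assumptions about how the `llm` handle is obtained; each yielded item is a `GenerationOutput` carrying the partial text plus the running logprob bookkeeping:

```python
async def stream_with_logprobs(llm, prompt: str) -> None:
    # generate_iterator yields GenerationOutput objects (see the annotation above).
    async for out in llm.generate_iterator(prompt, logprobs=1):
        chunk = out.outputs[0]
        print(chunk.text, chunk.cumulative_logprob)
```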

View File

@@ -8,7 +8,7 @@ import torch
import bentoml
import openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_vllm_available
@@ -58,7 +58,7 @@ class vLLMRunnable(bentoml.Runnable):
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
raise NotImplementedError('Adapter is not supported with vLLM.')
stop_: set[str] = set()
stop_ = set()
if isinstance(stop, str) and stop != '':
stop_.add(stop)
elif isinstance(stop, t.Iterable):
@@ -87,45 +87,87 @@ class PyTorchRunnable(bentoml.Runnable):
self.model = llm.model
self.tokenizer = llm.tokenizer
self.config = llm.config
if hasattr(llm.model, 'device'):
self.device = llm.model.device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.is_encoder_decoder = llm.model.config.is_encoder_decoder
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
self.model.set_adapter(adapter_name)
async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
yield generation_output.model_dump_json()
async def forward(self, prompt_token_ids, request_id, stop=None, **attrs):
from ._generation import is_partial_stop, prepare_logits_processor
stop_: set[str] = set()
stop_ = set()
if isinstance(stop, str) and stop != '':
stop_.add(stop)
elif isinstance(stop, t.Iterable):
stop_.update(stop)
config = self.config.model_construct_env(**attrs)
async for generation_output in self.forward(
prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs
):
yield generation_output.model_dump_json()
async def forward(self, prompt_token_ids, request_id, stop, **attrs):
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
max_new_tokens = attrs.pop('max_new_tokens', 256)
context_length = attrs.pop('context_length', None)
if context_length is None:
context_length = get_context_length(self.model.config)
if self.model.config.is_encoder_decoder:
max_src_len = context_length
else:
max_src_len = context_length - max_new_tokens - 1
prompt_token_ids = prompt_token_ids[-max_src_len:]
stop_token_ids = [self.tokenizer.encode(it) for it in stop]
if self.tokenizer.eos_token_id not in stop_token_ids: # add eos token
stop_token_ids.append(self.tokenizer.eos_token_id)
config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs)
logits_processor = prepare_logits_processor(config)
cumulative_logprob = 0.0
with torch.inference_mode():
# TODO: Support context_length check
# context_length: int | None = attrs.pop('context_length', None)
# if context_length is None: context_length = get_context_length(self.model.config)
# max_src_len = context_length - config['max_new_tokens'] - 1
# prompt_token_ids = prompt_token_ids[-max_src_len:]
output_token_ids = list(prompt_token_ids)
input_len = len(prompt_token_ids)
logits_processor = prepare_logits_processor(config)
if self.is_encoder_decoder:
if config['logprobs'] > 0: # FIXME: logprobs not yet supported for encoder-decoder models
raise NotImplementedError('Logprobs are not yet supported with encoder-decoder models.')
encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
start_ids = torch.as_tensor(
[[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device
)
else:
start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
past_key_values = out = token = None
finish_reason = None
prompt_logprobs = []
prompt_token_indices = []
for i in range(config['max_new_tokens']):
if i == 0: # prefill
out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True)
else: # decoding
out = self.model(
torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values
if self.is_encoder_decoder:
out = self.model.decoder(input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True)
logits = self.model.lm_head(out[0])
else:
out = self.model(input_ids=start_ids, use_cache=True)
logits = out.logits
elif self.is_encoder_decoder: # decoding
out = self.model.decoder(
input_ids=torch.as_tensor([[token]], device=self.device),
encoder_hidden_states=encoder_output,
past_key_values=past_key_values,
use_cache=True,
)
logits = out.logits
logits = self.model.lm_head(out[0])
else:
out = self.model(
input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True
)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
@@ -138,21 +180,34 @@ class PyTorchRunnable(bentoml.Runnable):
last_token_logits = logits[0, -1, :]
# Switch to CPU to avoid some bugs in the mps backend.
if self.model.device.type == 'mps':
if self.device.type == 'mps':
last_token_logits = last_token_logits.float().to('cpu')
# TODO: refactor for better sampling logic and apply penalties correctly
# support sequence generation, best_of
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
_, indices = torch.topk(last_token_logits, 2)
tokens = [int(index) for index in indices.tolist()]
else:
probs = torch.softmax(last_token_logits, dim=-1)
probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
indices = torch.multinomial(probs, num_samples=2)
tokens = [int(token) for token in indices.tolist()]
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_token_ids.append(token)
# NOTE: We can't use last_token_logits here since logprobs must be computed from the raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
sample_logprobs: SampleLogprobs = []
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs
stopped = False
if config['prompt_logprobs']:
for token_id in prompt_token_ids:
if token_id in prompt_token_indices:
continue
prompt_token_indices.append(token_id)
prompt_logprobs.append({token_id: logprobs[token_id].item()})
stopped = token in stop_token_ids
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
# XXX: Move this to API server
@@ -162,9 +217,10 @@ class PyTorchRunnable(bentoml.Runnable):
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
partially_stopped = False
if stop_:
for it in stop_:
if len(stop) > 0:
for it in stop:
pos = text.rfind(it, rfind_start)
if pos != -1:
text, stopped = text[:pos], True
@@ -173,16 +229,27 @@ class PyTorchRunnable(bentoml.Runnable):
partially_stopped = is_partial_stop(text, it)
if partially_stopped:
break
if config['logprobs']:
sample_logprobs.append({token: token_logprobs})
if not partially_stopped:
# TODO: calculate prompt_logprobs
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
index=0,
text=text,
token_ids=output_token_ids[input_len:],
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)
if stopped:
@@ -199,11 +266,13 @@ class PyTorchRunnable(bentoml.Runnable):
index=0,
text=text,
token_ids=output_token_ids[input_len:],
cumulative_logprob=0.0,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=finish_reason,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)
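
The per-step bookkeeping above reduces to a log-softmax over the raw logits; a standalone sketch of the arithmetic (shapes and values are illustrative):

```python
import torch

# Raw logits for the last position: shape (vocab_size,).
logits = torch.randn(32000)

# Logprobs come from the unmodified logits, not from last_token_logits,
# which may have been cast/moved for sampling (e.g. the mps workaround above).
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

token = int(torch.argmax(logprobs))     # greedy pick, for illustration only
token_logprob = logprobs[token].item()  # recorded as {token: token_logprob}

cumulative_logprob = 0.0
cumulative_logprob += token_logprob     # summed across decode steps
```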

View File

@@ -60,6 +60,14 @@ class vLLMRunnable(_Runnable[AsyncLLMEngine]): ...
@final
class PyTorchRunnable(_Runnable[PreTrainedModel]):
tokenizer: Any
async def forward(
self,
prompt_token_ids: List[int],
request_id: str,
stop: Iterable[str],
adapter_name: Optional[str] = ...,
**attrs: Any,
) -> AsyncGenerator[str, None]: ...
@overload
def runnable(backend: Literal['vllm']) -> Type[vLLMRunnable]: ...