From 1831d8f129f398ced9727d92425a037853e22bf0 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sat, 18 Nov 2023 19:26:20 -0500 Subject: [PATCH] feat: heuristics logprobs (#692) * fix(encoder): bring back T5 support on PyTorch Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * feat: support logprobs and prompt_logprobs Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * docs: update changelog Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- changelog.d/692.feature.md | 3 + .../src/openllm_core/_configuration.py | 22 ++- openllm-core/src/openllm_core/_schemas.py | 23 ++-- openllm-python/src/openllm/_llm.py | 4 +- openllm-python/src/openllm/_runners.py | 129 ++++++++++++++---- openllm-python/src/openllm/_runners.pyi | 8 ++ 6 files changed, 135 insertions(+), 54 deletions(-) create mode 100644 changelog.d/692.feature.md diff --git a/changelog.d/692.feature.md b/changelog.d/692.feature.md new file mode 100644 index 00000000..d7c0c100 --- /dev/null +++ b/changelog.d/692.feature.md @@ -0,0 +1,3 @@ +PyTorch runners now supports logprobs calculation for the `logits` output. + +Update logits calculation to support encoder-decoder models (which fix T5 inference) diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index eef5bc22..81adefcd 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -195,6 +195,9 @@ class GenerationConfig(ReprMixin): decoder_start_token_id: int = dantic.Field( description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.' ) + # NOTE: This is now implemented and supported for both PyTorch and vLLM + logprobs: int = dantic.Field(0, description='Number of log probabilities to return per output token.') + prompt_logprobs: int = dantic.Field(0, description='Number of log probabilities to return per input token.') def __init__(self, *, _internal: bool = False, **attrs: t.Any): if not _internal: @@ -224,7 +227,7 @@ converter.register_unstructure_hook_factory( ), ) -_GenerationConfigT = t.TypeVar('_GenerationConfig', bound=GenerationConfig) +_GenerationConfigT = t.TypeVar('_GenerationConfigT', bound=GenerationConfig) @attr.frozen(slots=True, repr=False, init=False) @@ -256,10 +259,6 @@ class SamplingParams(ReprMixin): False, description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.', ) - logprobs: int = dantic.Field(None, description='Number of log probabilities to return per output token.') - prompt_logprobs: t.Optional[int] = dantic.Field( - None, description='Number of log probabilities to return per input token.' 
- ) skip_special_tokens: bool = dantic.Field(True, description='Whether to skip special tokens in the generated output.') # space_between_special_tokens: bool = dantic.Field(True, description='Whether to add a space between special tokens in the generated output.') @@ -268,6 +267,7 @@ class SamplingParams(ReprMixin): temperature: float top_k: int top_p: float + logprobs: int def __init__(self, *, _internal: bool = False, **attrs: t.Any): if not _internal: @@ -319,8 +319,6 @@ class SamplingParams(ReprMixin): ) length_penalty = first_not_none(attrs.pop('length_penalty', None), default=generation_config['length_penalty']) early_stopping = first_not_none(attrs.pop('early_stopping', None), default=generation_config['early_stopping']) - logprobs = first_not_none(attrs.pop('logprobs', None), default=None) - prompt_logprobs = first_not_none(attrs.pop('prompt_logprobs', None), default=None) return cls( _internal=True, @@ -331,8 +329,6 @@ class SamplingParams(ReprMixin): repetition_penalty=repetition_penalty, length_penalty=length_penalty, early_stopping=early_stopping, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs, **attrs, ) @@ -1213,6 +1209,10 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): def __getitem__(self,item:t.Literal['encoder_no_repeat_ngram_size'])->int:... @overload def __getitem__(self,item:t.Literal['decoder_start_token_id'])->int:... + @overload + def __getitem__(self,item:t.Literal['logprobs'])->int:... + @overload + def __getitem__(self,item:t.Literal['prompt_logprobs'])->int:... # NOTE: SamplingParams arguments @overload def __getitem__(self,item:t.Literal['n'])->int:... @@ -1227,10 +1227,6 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): @overload def __getitem__(self,item:t.Literal['ignore_eos'])->bool:... @overload - def __getitem__(self,item:t.Literal['logprobs'])->int:... - @overload - def __getitem__(self,item:t.Literal['prompt_logprobs'])->t.Optional[int]:... - @overload def __getitem__(self,item:t.Literal['skip_special_tokens'])->bool:... # NOTE: PeftType arguments @overload diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py index 4d447e2d..23e54420 100644 --- a/openllm-core/src/openllm_core/_schemas.py +++ b/openllm-core/src/openllm_core/_schemas.py @@ -1,5 +1,3 @@ -"""Schema definition for OpenLLM. 
This schema is used throughout openllm core components library.""" - from __future__ import annotations import typing as t @@ -9,7 +7,7 @@ import orjson from ._configuration import LLMConfig from .config import AutoConfig -from .utils import converter, gen_random_uuid +from .utils import ReprMixin, converter, gen_random_uuid if t.TYPE_CHECKING: import vllm @@ -17,8 +15,15 @@ if t.TYPE_CHECKING: from ._typing_compat import Self -@attr.define -class _SchemaMixin: +@attr.define(repr=False) +class _SchemaMixin(ReprMixin): + @property + def __repr_keys__(self): + return list(attr.fields_dict(self.__class__)) + + def __repr_args__(self): + yield from ((k, getattr(self, k)) for k in self.__repr_keys__) + def model_dump(self) -> dict[str, t.Any]: return converter.unstructure(self) @@ -29,7 +34,7 @@ class _SchemaMixin: return attr.evolve(self, **options) -@attr.define +@attr.define(repr=False) class MetadataOutput(_SchemaMixin): model_id: str timeout: int @@ -51,7 +56,7 @@ class MetadataOutput(_SchemaMixin): } -@attr.define +@attr.define(repr=False) class GenerationInput(_SchemaMixin): prompt: str llm_config: LLMConfig @@ -111,7 +116,7 @@ PromptLogprobs = t.List[t.Optional[t.Dict[int, float]]] FinishReason = t.Literal['length', 'stop'] -@attr.define +@attr.define(repr=False) class CompletionChunk(_SchemaMixin): index: int text: str @@ -124,7 +129,7 @@ class CompletionChunk(_SchemaMixin): return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8') -@attr.define +@attr.define(repr=False) class GenerationOutput(_SchemaMixin): prompt: str finished: bool diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 3b4807eb..ed086894 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -387,7 +387,7 @@ class LLM(t.Generic[M, T], ReprMixin): async def generate( self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs - ): + ) -> GenerationOutput: config = self.config.model_construct_env(**attrs) texts, token_ids = [[]] * config['n'], [[]] * config['n'] final_result = None @@ -410,7 +410,7 @@ class LLM(t.Generic[M, T], ReprMixin): async def generate_iterator( self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs - ): + ) -> t.AsyncGenerator[GenerationOutput, None]: from bentoml._internal.runner.runner_handle import DummyRunnerHandle if isinstance(self.runner._runner_handle, DummyRunnerHandle): diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index 16b26916..b3c2982c 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -8,7 +8,7 @@ import torch import bentoml import openllm -from openllm_core._schemas import CompletionChunk, GenerationOutput +from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs from openllm_core.exceptions import OpenLLMException from openllm_core.utils import first_not_none, is_vllm_available @@ -58,7 +58,7 @@ class vLLMRunnable(bentoml.Runnable): async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): if adapter_name is not None: raise NotImplementedError('Adapter is not supported with vLLM.') - stop_: set[str] = set() + stop_ = set() if isinstance(stop, str) and stop != '': stop_.add(stop) elif isinstance(stop, t.Iterable): @@ -87,45 +87,87 @@ class PyTorchRunnable(bentoml.Runnable): self.model = llm.model 
self.tokenizer = llm.tokenizer self.config = llm.config + if hasattr(llm.model, 'device'): + self.device = llm.model.device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.is_encoder_decoder = llm.model.config.is_encoder_decoder @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): if adapter_name is not None: self.model.set_adapter(adapter_name) - async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs): - yield generation_output.model_dump_json() - - async def forward(self, prompt_token_ids, request_id, stop=None, **attrs): - from ._generation import is_partial_stop, prepare_logits_processor - - stop_: set[str] = set() + stop_ = set() if isinstance(stop, str) and stop != '': stop_.add(stop) elif isinstance(stop, t.Iterable): stop_.update(stop) - config = self.config.model_construct_env(**attrs) + async for generation_output in self.forward( + prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs + ): + yield generation_output.model_dump_json() + + async def forward(self, prompt_token_ids, request_id, stop, **attrs): + from ._generation import get_context_length, is_partial_stop, prepare_logits_processor + + max_new_tokens = attrs.pop('max_new_tokens', 256) + context_length = attrs.pop('context_length', None) + if context_length is None: + context_length = get_context_length(self.model.config) + if self.model.config.is_encoder_decoder: + max_src_len = context_length + else: + max_src_len = context_length - max_new_tokens - 1 + prompt_token_ids = prompt_token_ids[-max_src_len:] + + stop_token_ids = [self.tokenizer.encode(it) for it in stop] + if self.tokenizer.eos_token_id not in stop_token_ids: # add eos token + stop_token_ids.append(self.tokenizer.eos_token_id) + + config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs) + logits_processor = prepare_logits_processor(config) + cumulative_logprob = 0.0 with torch.inference_mode(): - # TODO: Support context_length check - # context_length: int | None = attrs.pop('context_length', None) - # if context_length is None: context_length = get_context_length(self.model.config) - # max_src_len = context_length - config['max_new_tokens'] - 1 - # prompt_token_ids = prompt_token_ids[-max_src_len:] output_token_ids = list(prompt_token_ids) input_len = len(prompt_token_ids) - logits_processor = prepare_logits_processor(config) + if self.is_encoder_decoder: + if config['logprobs'] > 0: # FIXME: logprobs is not supported + raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.') + encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0] + start_ids = torch.as_tensor( + [[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device + ) + else: + start_ids = torch.as_tensor([prompt_token_ids], device=self.device) past_key_values = out = token = None finish_reason = None + prompt_logprobs = [] + prompt_token_indices = [] + for i in range(config['max_new_tokens']): if i == 0: # prefill - out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True) - else: # decoding - out = self.model( - torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values + if self.is_encoder_decoder: + out = self.model.decoder(input_ids=start_ids, 
encoder_hidden_states=encoder_output, use_cache=True) + logits = self.model.lm_head(out[0]) + else: + out = self.model(input_ids=start_ids, use_cache=True) + logits = out.logits + elif self.is_encoder_decoder: # decoding + out = self.model.decoder( + input_ids=torch.as_tensor([[token]], device=self.device), + encoder_hidden_states=encoder_output, + past_key_values=past_key_values, + use_cache=True, ) - logits = out.logits + logits = self.model.lm_head(out[0]) + else: + out = self.model( + input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True + ) + logits = out.logits past_key_values = out.past_key_values if logits_processor: @@ -138,21 +180,34 @@ class PyTorchRunnable(bentoml.Runnable): last_token_logits = logits[0, -1, :] # Switch to CPU by avoiding some bugs in mps backend. - if self.model.device.type == 'mps': + if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu') + # TODO: refactor for better sampling logic and apply penalties correctly + # support sequence generation, best_of if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy _, indices = torch.topk(last_token_logits, 2) - tokens = [int(index) for index in indices.tolist()] else: - probs = torch.softmax(last_token_logits, dim=-1) + probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float) indices = torch.multinomial(probs, num_samples=2) - tokens = [int(token) for token in indices.tolist()] + tokens = [int(token) for token in indices.tolist()] token = tokens[0] output_token_ids.append(token) + # NOTE: We can't use last_token_logits since logprobs is based on raw logits + logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float) + sample_logprobs: SampleLogprobs = [] + token_logprobs = logprobs[token].item() + cumulative_logprob += token_logprobs - stopped = False + if config['prompt_logprobs']: + for token_id in prompt_token_ids: + if token_id in prompt_token_indices: + continue + prompt_token_indices.append(token_id) + prompt_logprobs.append({token_id: logprobs[token_id].item()}) + + stopped = token in stop_token_ids tmp_output_ids, rfind_start = output_token_ids[input_len:], 0 # XXX: Move this to API server @@ -162,9 +217,10 @@ class PyTorchRunnable(bentoml.Runnable): spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, ) + partially_stopped = False - if stop_: - for it in stop_: + if len(stop) > 0: + for it in stop: pos = text.rfind(it, rfind_start) if pos != -1: text, stopped = text[:pos], True @@ -173,16 +229,27 @@ class PyTorchRunnable(bentoml.Runnable): partially_stopped = is_partial_stop(text, it) if partially_stopped: break + + if config['logprobs']: + sample_logprobs.append({token: token_logprobs}) + if not partially_stopped: + # TODO: calculate prompt_logprobs yield GenerationOutput( prompt='', finished=False, outputs=[ CompletionChunk( - index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None + index=0, + text=text, + token_ids=output_token_ids[input_len:], + cumulative_logprob=cumulative_logprob, + logprobs=sample_logprobs if config['logprobs'] else None, + finish_reason=None, ) ], prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs, request_id=request_id, ) if stopped: @@ -199,11 +266,13 @@ class PyTorchRunnable(bentoml.Runnable): index=0, text=text, token_ids=output_token_ids[input_len:], - cumulative_logprob=0.0, + cumulative_logprob=cumulative_logprob, + logprobs=sample_logprobs if config['logprobs'] else None, 
finish_reason=finish_reason, ) ], prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs, request_id=request_id, ) diff --git a/openllm-python/src/openllm/_runners.pyi b/openllm-python/src/openllm/_runners.pyi index a1ab4d7f..ece66350 100644 --- a/openllm-python/src/openllm/_runners.pyi +++ b/openllm-python/src/openllm/_runners.pyi @@ -60,6 +60,14 @@ class vLLMRunnable(_Runnable[AsyncLLMEngine]): ... @final class PyTorchRunnable(_Runnable[PreTrainedModel]): tokenizer: Any + async def forward( + self, + prompt_token_ids: List[int], + request_id: str, + stop: Iterable[str], + adapter_name: Optional[str] = ..., + **attrs: Any, + ) -> AsyncGenerator[str, None]: ... @overload def runnable(backend: Literal['vllm']) -> Type[vLLMRunnable]: ...
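A minimal, self-contained sketch of the per-token logprob bookkeeping this patch introduces in the PyTorch runner: logprobs are taken from the raw last-position logits with `torch.log_softmax`, the sampled token's logprob is accumulated into `cumulative_logprob`, and an optional `{token_id: logprob}` entry is recorded per step when `config['logprobs']` is set. The helper name `step_logprobs` and the random logits below are illustrative only and are not part of the OpenLLM API.

```python
import torch

def step_logprobs(logits: torch.Tensor, token: int, num_logprobs: int = 0):
    # logits: raw logits for the last position, shape [vocab_size].
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
    token_logprob = logprobs[token].item()
    sample_logprobs = None
    if num_logprobs > 0:
        # Record the sampled token's logprob; a fuller implementation would also
        # keep the top-`num_logprobs` alternative tokens at this position.
        sample_logprobs = {token: token_logprob}
    return token_logprob, sample_logprobs

# Usage: accumulate a cumulative logprob over a short greedy rollout.
vocab_size, cumulative_logprob = 32_000, 0.0
for _ in range(4):
    fake_logits = torch.randn(vocab_size)          # stand-in for model output
    token = int(torch.argmax(fake_logits))         # greedy pick
    token_logprob, sample_logprobs = step_logprobs(fake_logits, token, num_logprobs=1)
    cumulative_logprob += token_logprob
```

Taking the log-softmax over `logits[0, -1, :]` rather than `last_token_logits` follows the NOTE in the diff: on the mps backend `last_token_logits` is cast to float and moved to CPU before sampling, so the raw logits are the consistent source for the reported probabilities.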