feat: heuristics logprobs (#692)

* fix(encoder): bring back T5 support on PyTorch

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: support logprobs and prompt_logprobs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* docs: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-18 19:26:20 -05:00
committed by GitHub
parent 4499469efb
commit 1831d8f129
6 changed files with 135 additions and 54 deletions

View File

@@ -387,7 +387,7 @@ class LLM(t.Generic[M, T], ReprMixin):
async def generate(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
) -> GenerationOutput:
config = self.config.model_construct_env(**attrs)
texts, token_ids = [[]] * config['n'], [[]] * config['n']
final_result = None
@@ -410,7 +410,7 @@ class LLM(t.Generic[M, T], ReprMixin):
async def generate_iterator(
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
):
) -> t.AsyncGenerator[GenerationOutput, None]:
from bentoml._internal.runner.runner_handle import DummyRunnerHandle
if isinstance(self.runner._runner_handle, DummyRunnerHandle):

View File

@@ -8,7 +8,7 @@ import torch
import bentoml
import openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_vllm_available
@@ -58,7 +58,7 @@ class vLLMRunnable(bentoml.Runnable):
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
raise NotImplementedError('Adapter is not supported with vLLM.')
stop_: set[str] = set()
stop_ = set()
if isinstance(stop, str) and stop != '':
stop_.add(stop)
elif isinstance(stop, t.Iterable):
@@ -87,45 +87,87 @@ class PyTorchRunnable(bentoml.Runnable):
self.model = llm.model
self.tokenizer = llm.tokenizer
self.config = llm.config
if hasattr(llm.model, 'device'):
self.device = llm.model.device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.is_encoder_decoder = llm.model.config.is_encoder_decoder
@bentoml.Runnable.method(batchable=False)
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
if adapter_name is not None:
self.model.set_adapter(adapter_name)
async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
yield generation_output.model_dump_json()
async def forward(self, prompt_token_ids, request_id, stop=None, **attrs):
from ._generation import is_partial_stop, prepare_logits_processor
stop_: set[str] = set()
stop_ = set()
if isinstance(stop, str) and stop != '':
stop_.add(stop)
elif isinstance(stop, t.Iterable):
stop_.update(stop)
config = self.config.model_construct_env(**attrs)
async for generation_output in self.forward(
prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs
):
yield generation_output.model_dump_json()
async def forward(self, prompt_token_ids, request_id, stop, **attrs):
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
max_new_tokens = attrs.pop('max_new_tokens', 256)
context_length = attrs.pop('context_length', None)
if context_length is None:
context_length = get_context_length(self.model.config)
if self.model.config.is_encoder_decoder:
max_src_len = context_length
else:
max_src_len = context_length - max_new_tokens - 1
prompt_token_ids = prompt_token_ids[-max_src_len:]
stop_token_ids = [self.tokenizer.encode(it) for it in stop]
if self.tokenizer.eos_token_id not in stop_token_ids: # add eos token
stop_token_ids.append(self.tokenizer.eos_token_id)
config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs)
logits_processor = prepare_logits_processor(config)
cumulative_logprob = 0.0
with torch.inference_mode():
# TODO: Support context_length check
# context_length: int | None = attrs.pop('context_length', None)
# if context_length is None: context_length = get_context_length(self.model.config)
# max_src_len = context_length - config['max_new_tokens'] - 1
# prompt_token_ids = prompt_token_ids[-max_src_len:]
output_token_ids = list(prompt_token_ids)
input_len = len(prompt_token_ids)
logits_processor = prepare_logits_processor(config)
if self.is_encoder_decoder:
if config['logprobs'] > 0: # FIXME: logprobs is not supported
raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.')
encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
start_ids = torch.as_tensor(
[[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device
)
else:
start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
past_key_values = out = token = None
finish_reason = None
prompt_logprobs = []
prompt_token_indices = []
for i in range(config['max_new_tokens']):
if i == 0: # prefill
out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True)
else: # decoding
out = self.model(
torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values
if self.is_encoder_decoder:
out = self.model.decoder(input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True)
logits = self.model.lm_head(out[0])
else:
out = self.model(input_ids=start_ids, use_cache=True)
logits = out.logits
elif self.is_encoder_decoder: # decoding
out = self.model.decoder(
input_ids=torch.as_tensor([[token]], device=self.device),
encoder_hidden_states=encoder_output,
past_key_values=past_key_values,
use_cache=True,
)
logits = out.logits
logits = self.model.lm_head(out[0])
else:
out = self.model(
input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True
)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
@@ -138,21 +180,34 @@ class PyTorchRunnable(bentoml.Runnable):
last_token_logits = logits[0, -1, :]
# Switch to CPU by avoiding some bugs in mps backend.
if self.model.device.type == 'mps':
if self.device.type == 'mps':
last_token_logits = last_token_logits.float().to('cpu')
# TODO: refactor for better sampling logic and apply penalties correctly
# support sequence generation, best_of
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
_, indices = torch.topk(last_token_logits, 2)
tokens = [int(index) for index in indices.tolist()]
else:
probs = torch.softmax(last_token_logits, dim=-1)
probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
indices = torch.multinomial(probs, num_samples=2)
tokens = [int(token) for token in indices.tolist()]
tokens = [int(token) for token in indices.tolist()]
token = tokens[0]
output_token_ids.append(token)
# NOTE: We can't use last_token_logits since logprobs is based on raw logits
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
sample_logprobs: SampleLogprobs = []
token_logprobs = logprobs[token].item()
cumulative_logprob += token_logprobs
stopped = False
if config['prompt_logprobs']:
for token_id in prompt_token_ids:
if token_id in prompt_token_indices:
continue
prompt_token_indices.append(token_id)
prompt_logprobs.append({token_id: logprobs[token_id].item()})
stopped = token in stop_token_ids
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
# XXX: Move this to API server
@@ -162,9 +217,10 @@ class PyTorchRunnable(bentoml.Runnable):
spaces_between_special_tokens=False,
clean_up_tokenization_spaces=True,
)
partially_stopped = False
if stop_:
for it in stop_:
if len(stop) > 0:
for it in stop:
pos = text.rfind(it, rfind_start)
if pos != -1:
text, stopped = text[:pos], True
@@ -173,16 +229,27 @@ class PyTorchRunnable(bentoml.Runnable):
partially_stopped = is_partial_stop(text, it)
if partially_stopped:
break
if config['logprobs']:
sample_logprobs.append({token: token_logprobs})
if not partially_stopped:
# TODO: calculate prompt_logprobs
yield GenerationOutput(
prompt='',
finished=False,
outputs=[
CompletionChunk(
index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
index=0,
text=text,
token_ids=output_token_ids[input_len:],
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=None,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)
if stopped:
@@ -199,11 +266,13 @@ class PyTorchRunnable(bentoml.Runnable):
index=0,
text=text,
token_ids=output_token_ids[input_len:],
cumulative_logprob=0.0,
cumulative_logprob=cumulative_logprob,
logprobs=sample_logprobs if config['logprobs'] else None,
finish_reason=finish_reason,
)
],
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
request_id=request_id,
)

View File

@@ -60,6 +60,14 @@ class vLLMRunnable(_Runnable[AsyncLLMEngine]): ...
@final
class PyTorchRunnable(_Runnable[PreTrainedModel]):
tokenizer: Any
async def forward(
self,
prompt_token_ids: List[int],
request_id: str,
stop: Iterable[str],
adapter_name: Optional[str] = ...,
**attrs: Any,
) -> AsyncGenerator[str, None]: ...
@overload
def runnable(backend: Literal['vllm']) -> Type[vLLMRunnable]: ...