Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-19 21:08:22 -05:00)
feat: heuristics logprobs (#692)
* fix(encoder): bring back T5 support on PyTorch
* feat: support logprobs and prompt_logprobs
* docs: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
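At a high level, the logprobs support added here takes a log-softmax over the raw logits at each decoding step, reads off the chosen token's log-probability, and sums those values into `cumulative_logprob`; `prompt_logprobs` are approximated by indexing that same distribution at the prompt's token ids (hence "heuristics"). A minimal, self-contained sketch of that primitive, using dummy tensors rather than the runner's real state:

```python
import torch

# Illustrative only: dummy logits stand in for one decoding step of a real model.
vocab_size = 8
logits = torch.randn(1, 3, vocab_size)  # (batch, sequence, vocab)

# Log-probabilities are taken over the raw logits of the last position.
logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)

token = int(torch.argmax(logprobs))     # greedily pick a next token for the example
token_logprob = logprobs[token].item()  # log P(token | prefix)

cumulative_logprob = 0.0
cumulative_logprob += token_logprob     # summed across decoding steps in the runner
print(token, token_logprob, cumulative_logprob)
```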
```diff
@@ -387,7 +387,7 @@ class LLM(t.Generic[M, T], ReprMixin):

   async def generate(
     self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
-  ):
+  ) -> GenerationOutput:
     config = self.config.model_construct_env(**attrs)
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
     final_result = None
```
```diff
@@ -410,7 +410,7 @@ class LLM(t.Generic[M, T], ReprMixin):

   async def generate_iterator(
     self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
-  ):
+  ) -> t.AsyncGenerator[GenerationOutput, None]:
     from bentoml._internal.runner.runner_handle import DummyRunnerHandle

     if isinstance(self.runner._runner_handle, DummyRunnerHandle):
```
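The two hunks above only tighten the return annotations on `LLM.generate` and `LLM.generate_iterator`. A hedged usage sketch, assuming `llm` is an already-constructed `openllm.LLM` instance (construction itself is outside this diff):

```python
import typing as t

import openllm

async def demo(llm: openllm.LLM[t.Any, t.Any], prompt: str) -> None:
  # One-shot call: generate(...) now advertises a single GenerationOutput.
  result = await llm.generate(prompt)
  print(result.outputs[0].text)

  # Streaming call: generate_iterator(...) is an AsyncGenerator[GenerationOutput, None].
  async for chunk in llm.generate_iterator(prompt):
    print(chunk.outputs[0].text, end='', flush=True)
```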
```diff
@@ -8,7 +8,7 @@ import torch

 import bentoml
 import openllm
-from openllm_core._schemas import CompletionChunk, GenerationOutput
+from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
 from openllm_core.exceptions import OpenLLMException
 from openllm_core.utils import first_not_none, is_vllm_available

```
```diff
@@ -58,7 +58,7 @@ class vLLMRunnable(bentoml.Runnable):
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
     if adapter_name is not None:
       raise NotImplementedError('Adapter is not supported with vLLM.')
-    stop_: set[str] = set()
+    stop_ = set()
     if isinstance(stop, str) and stop != '':
       stop_.add(stop)
     elif isinstance(stop, t.Iterable):
```
```diff
@@ -87,45 +87,87 @@ class PyTorchRunnable(bentoml.Runnable):
     self.model = llm.model
     self.tokenizer = llm.tokenizer
     self.config = llm.config
     if hasattr(llm.model, 'device'):
       self.device = llm.model.device
     else:
       self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    self.is_encoder_decoder = llm.model.config.is_encoder_decoder

   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
     if adapter_name is not None:
       self.model.set_adapter(adapter_name)
-    async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
-      yield generation_output.model_dump_json()
-
-  async def forward(self, prompt_token_ids, request_id, stop=None, **attrs):
-    from ._generation import is_partial_stop, prepare_logits_processor
-
-    stop_: set[str] = set()
+    stop_ = set()
     if isinstance(stop, str) and stop != '':
       stop_.add(stop)
     elif isinstance(stop, t.Iterable):
       stop_.update(stop)
-    config = self.config.model_construct_env(**attrs)
+    async for generation_output in self.forward(
+      prompt_token_ids=prompt_token_ids, request_id=request_id, stop=list(stop_), **attrs
+    ):
+      yield generation_output.model_dump_json()
+
+  async def forward(self, prompt_token_ids, request_id, stop, **attrs):
+    from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
+
+    max_new_tokens = attrs.pop('max_new_tokens', 256)
+    context_length = attrs.pop('context_length', None)
+    if context_length is None:
+      context_length = get_context_length(self.model.config)
+    if self.model.config.is_encoder_decoder:
+      max_src_len = context_length
+    else:
+      max_src_len = context_length - max_new_tokens - 1
+    prompt_token_ids = prompt_token_ids[-max_src_len:]
+
+    stop_token_ids = [self.tokenizer.encode(it) for it in stop]
+    if self.tokenizer.eos_token_id not in stop_token_ids: # add eos token
+      stop_token_ids.append(self.tokenizer.eos_token_id)
+
+    config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs)
+    logits_processor = prepare_logits_processor(config)
+    cumulative_logprob = 0.0
+
     with torch.inference_mode():
-      # TODO: Support context_length check
-      # context_length: int | None = attrs.pop('context_length', None)
-      # if context_length is None: context_length = get_context_length(self.model.config)
-      # max_src_len = context_length - config['max_new_tokens'] - 1
-      # prompt_token_ids = prompt_token_ids[-max_src_len:]
       output_token_ids = list(prompt_token_ids)
       input_len = len(prompt_token_ids)

-      logits_processor = prepare_logits_processor(config)
+      if self.is_encoder_decoder:
+        if config['logprobs'] > 0: # FIXME: logprobs is not supported
+          raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.')
+        encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
+        start_ids = torch.as_tensor(
+          [[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device
+        )
+      else:
+        start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
+
       past_key_values = out = token = None
       finish_reason = None
+      prompt_logprobs = []
+      prompt_token_indices = []

       for i in range(config['max_new_tokens']):
         if i == 0: # prefill
-          out = self.model(torch.as_tensor([prompt_token_ids], device=self.model.device), use_cache=True)
-        else: # decoding
-          out = self.model(
-            torch.as_tensor([[token]], device=self.model.device), use_cache=True, past_key_values=past_key_values
+          if self.is_encoder_decoder:
+            out = self.model.decoder(input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True)
+            logits = self.model.lm_head(out[0])
+          else:
+            out = self.model(input_ids=start_ids, use_cache=True)
+            logits = out.logits
+        elif self.is_encoder_decoder: # decoding
+          out = self.model.decoder(
+            input_ids=torch.as_tensor([[token]], device=self.device),
+            encoder_hidden_states=encoder_output,
+            past_key_values=past_key_values,
+            use_cache=True,
           )
-        logits = out.logits
+          logits = self.model.lm_head(out[0])
+        else:
+          out = self.model(
+            input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True
+          )
+          logits = out.logits
         past_key_values = out.past_key_values

         if logits_processor:
```
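The rewritten `forward` also resolves the old context-length TODO: for decoder-only models the prompt is clipped so that the prompt plus `max_new_tokens` (plus one slot) fits in the model's context window, while encoder-decoder models keep the full `context_length` because new tokens are produced on the decoder side. A worked example with made-up numbers (in the runner, `get_context_length(...)` derives the window from the model config):

```python
# Hypothetical numbers, chosen only to make the arithmetic visible.
context_length = 2048
max_new_tokens = 256

max_src_len = context_length - max_new_tokens - 1   # 1791 positions left for the prompt

prompt_token_ids = list(range(4000))                 # pretend the prompt is far too long
prompt_token_ids = prompt_token_ids[-max_src_len:]   # keep only the most recent ids
assert len(prompt_token_ids) == 1791
```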
```diff
@@ -138,21 +180,34 @@ class PyTorchRunnable(bentoml.Runnable):
           last_token_logits = logits[0, -1, :]

         # Switch to CPU by avoiding some bugs in mps backend.
-        if self.model.device.type == 'mps':
+        if self.device.type == 'mps':
           last_token_logits = last_token_logits.float().to('cpu')

+        # TODO: refactor for better sampling logic and apply penalties correctly
+        # support sequence generation, best_of
         if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
           _, indices = torch.topk(last_token_logits, 2)
           tokens = [int(index) for index in indices.tolist()]
         else:
-          probs = torch.softmax(last_token_logits, dim=-1)
+          probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
           indices = torch.multinomial(probs, num_samples=2)
-          tokens = [int(token) for token in indices.tolist()]
+          tokens = [int(token) for token in indices.tolist()]

         token = tokens[0]
         output_token_ids.append(token)
+        # NOTE: We can't use last_token_logits since logprobs is based on raw logits
+        logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
+        sample_logprobs: SampleLogprobs = []
+        token_logprobs = logprobs[token].item()
+        cumulative_logprob += token_logprobs

-        stopped = False
+        if config['prompt_logprobs']:
+          for token_id in prompt_token_ids:
+            if token_id in prompt_token_indices:
+              continue
+            prompt_token_indices.append(token_id)
+            prompt_logprobs.append({token_id: logprobs[token_id].item()})
+
+        stopped = token in stop_token_ids

         tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
         # XXX: Move this to API server
```
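The sampling branch above is mostly unchanged apart from the `dtype=torch.float` cast, but it is easy to misread in diff form: a near-zero temperature (or top_p) degenerates to greedy top-k, otherwise two candidates are drawn from the softmax distribution and only the first is used. A standalone sketch with dummy logits:

```python
import torch

temperature, top_p = 0.7, 0.95
last_token_logits = torch.randn(32)  # dummy logits over a 32-token vocabulary

if temperature < 1e-5 or top_p < 1e-8:  # greedy
  _, indices = torch.topk(last_token_logits, 2)
else:
  probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
  indices = torch.multinomial(probs, num_samples=2)

tokens = [int(i) for i in indices.tolist()]
token = tokens[0]  # only the first candidate is appended to output_token_ids
print(token)
```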
```diff
@@ -162,9 +217,10 @@ class PyTorchRunnable(bentoml.Runnable):
           spaces_between_special_tokens=False,
           clean_up_tokenization_spaces=True,
         )

         partially_stopped = False
-        if stop_:
-          for it in stop_:
+        if len(stop) > 0:
+          for it in stop:
             pos = text.rfind(it, rfind_start)
             if pos != -1:
               text, stopped = text[:pos], True
```
```diff
@@ -173,16 +229,27 @@ class PyTorchRunnable(bentoml.Runnable):
               partially_stopped = is_partial_stop(text, it)
               if partially_stopped:
                 break
+
+        if config['logprobs']:
+          sample_logprobs.append({token: token_logprobs})
+
         if not partially_stopped:
+          # TODO: calculate prompt_logprobs
           yield GenerationOutput(
             prompt='',
             finished=False,
             outputs=[
               CompletionChunk(
-                index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
+                index=0,
+                text=text,
+                token_ids=output_token_ids[input_len:],
+                cumulative_logprob=cumulative_logprob,
+                logprobs=sample_logprobs if config['logprobs'] else None,
+                finish_reason=None,
               )
             ],
             prompt_token_ids=prompt_token_ids,
+            prompt_logprobs=prompt_logprobs,
             request_id=request_id,
           )
         if stopped:
```
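The partial-stop handling above exists so that a chunk which might be the beginning of a stop string is held back rather than streamed. `is_partial_stop` lives in the `_generation` module and is not shown in this diff; the stand-in below, deliberately named differently, only illustrates the presumed idea (the generated text ends with a prefix of a stop string):

```python
# Hypothetical helper; not the implementation from openllm's _generation module.
def is_partial_stop_like(output: str, stop_str: str) -> bool:
  # True when the tail of `output` is a prefix of `stop_str`, i.e. the next
  # tokens could still complete the stop sequence.
  return any(stop_str.startswith(output[-i:]) for i in range(1, min(len(output), len(stop_str)) + 1))

assert is_partial_stop_like('... the answer is<|end', '<|endoftext|>')
assert not is_partial_stop_like('... the answer is 42', '<|endoftext|>')
```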
```diff
@@ -199,11 +266,13 @@ class PyTorchRunnable(bentoml.Runnable):
             index=0,
             text=text,
             token_ids=output_token_ids[input_len:],
-            cumulative_logprob=0.0,
+            cumulative_logprob=cumulative_logprob,
+            logprobs=sample_logprobs if config['logprobs'] else None,
             finish_reason=finish_reason,
           )
         ],
         prompt_token_ids=prompt_token_ids,
+        prompt_logprobs=prompt_logprobs,
         request_id=request_id,
       )
```
```diff
@@ -60,6 +60,14 @@ class vLLMRunnable(_Runnable[AsyncLLMEngine]): ...
 @final
 class PyTorchRunnable(_Runnable[PreTrainedModel]):
   tokenizer: Any
+  async def forward(
+    self,
+    prompt_token_ids: List[int],
+    request_id: str,
+    stop: Iterable[str],
+    adapter_name: Optional[str] = ...,
+    **attrs: Any,
+  ) -> AsyncGenerator[str, None]: ...

 @overload
 def runnable(backend: Literal['vllm']) -> Type[vLLMRunnable]: ...
```