OpenLLM/openllm-python/src/openllm/_runners.py
Zhao Shenyang aff5dc8ff2 fix: proper SSE handling for vllm (#877)
2024-02-02 17:25:58 +08:00

from __future__ import annotations
import gc, traceback, types, typing as t
import torch, bentoml, openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available

if t.TYPE_CHECKING:
  from openllm_core._typing_compat import M, T
  from ._runners import Runner

_registry = {}
__all__ = ['runner']


def registry(cls=None, *, alias=None):
  def decorator(_cls):
    _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
    return _cls

  if cls is None:
    return decorator
  return decorator(cls)
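

# ``registry`` keys each Runnable class by backend name: by default the class name minus its
# ``Runnable`` suffix, lowercased (CTranslateRunnable -> 'ctranslate', vLLMRunnable -> 'vllm'),
# or an explicit ``alias`` (PyTorchRunnable -> 'pt'). ``runner()`` below then resolves the class
# to instantiate through ``_registry[llm.__llm_backend__]``.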


def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
  try:
    assert llm.bentomodel
  except (bentoml.exceptions.NotFound, AssertionError) as err:
    raise RuntimeError(f'Failed to locate {llm.bentomodel}: {err}') from err
  return types.new_class(
    llm.config.__class__.__name__[:-6] + 'Runner',
    (bentoml.Runner,),  #
    exec_body=lambda ns: ns.update({
      'llm_type': llm.llm_type,
      'identifying_params': llm.identifying_params,  #
      'llm_tag': llm.tag,
      'llm': llm,
      'config': llm.config,
      'backend': llm.__llm_backend__,  #
      '__module__': llm.__module__,
      '__repr__': ReprMixin.__repr__,  #
      '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
      '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
      '__repr_args__': lambda _: (
        (
          'runner_methods',
          {
            method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None}
            for method in _.runner_methods
          },
        ),
        ('config', llm.config.model_dump(flatten=True)),
        ('llm_type', llm.llm_type),
        ('backend', llm.__llm_backend__),
        ('llm_tag', llm.tag),
      ),
      'has_adapters': llm.has_adapters,
      'template': llm.config.template,
      'system_message': llm.config.system_message,
    }),
  )(
    _registry[llm.__llm_backend__],
    name=f"llm-{llm.config['start_name']}-runner",
    models=[llm.bentomodel],
    scheduling_strategy=openllm.CascadingResourceStrategy,
    runnable_init_params={'llm': llm},
  )
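

# Illustrative usage (a sketch, not part of this module; the constructor arguments and the
# streaming runner-method API shown here are assumptions and may differ between versions):
#
#   llm = openllm.LLM('facebook/opt-125m', backend='pt')  # hypothetical model id
#   llm_runner = runner(llm)
#   svc = bentoml.Service('llm-service', runners=[llm_runner])
#
# At request time the runner streams JSON-serialized GenerationOutput chunks, e.g. assuming
# BentoML's ``async_stream`` runner-method API:
#
#   async for chunk in llm_runner.generate_iterator.async_stream(prompt_token_ids, request_id, stop=['</s>']):
#     ...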


@registry
class CTranslateRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    if not is_ctranslate_available():
      raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
    cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
    tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
    async for request_output in self.model.async_generate_tokens(tokens, **sampling_params):
      if config['logprobs']:
        cumulative_logprob += request_output.log_prob
      output_token_ids.append(request_output.token_id)
      text = self.tokenizer.decode(
        output_token_ids[input_len:],
        skip_special_tokens=True,  #
        spaces_between_special_tokens=False,
        clean_up_tokenization_spaces=True,  #
      )
      yield GenerationOutput(
        prompt_token_ids=prompt_token_ids,  #
        prompt='',
        finished=request_output.is_last,
        request_id=request_id,  #
        outputs=[
          CompletionChunk(
            index=0,
            text=text,
            finish_reason=None,  #
            token_ids=output_token_ids[input_len:],
            cumulative_logprob=cumulative_logprob,  #
            # TODO: logprobs, but seems like we don't have access to the raw logits
          )
        ],
      ).model_dump_json()
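

# NOTE: each chunk yielded above re-decodes every token generated so far, so ``text`` is the
# cumulative completion rather than a delta, and ``cumulative_logprob`` is a running sum that is
# only updated when ``logprobs`` is requested.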


@registry
class vLLMRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    if not is_vllm_available():
      raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
    import vllm

    self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
    num_gpus, dev = 1, openllm.utils.device_count()
    if dev >= 2:
      num_gpus = min(dev // 2 * 2, dev)
    quantise = llm.quantise if llm.quantise and llm.quantise in {'gptq', 'awq', 'squeezellm'} else None
    dtype = torch.float16 if quantise == 'gptq' else llm._torch_dtype  # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
    try:
      self.model = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(
          worker_use_ray=False,
          engine_use_ray=False,  #
          tokenizer_mode='auto',
          tensor_parallel_size=num_gpus,  #
          model=llm.bentomodel.path,
          tokenizer=llm.bentomodel.path,  #
          trust_remote_code=llm.trust_remote_code,
          dtype=dtype,  #
          max_model_len=llm._max_model_len,
          gpu_memory_utilization=llm._gpu_memory_utilization,  #
          quantization=quantise,
        )
      )
    except Exception as err:
      traceback.print_exc()
      raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    _, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
    async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
      out = GenerationOutput.from_vllm(request_output).model_dump_json()
      out = bentoml.io.SSE(out).marshal()
      yield out
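

# NOTE: unlike the other backends, vLLMRunnable wraps each serialized GenerationOutput in a
# server-sent-event frame via ``bentoml.io.SSE(out).marshal()`` (roughly ``data: <json>`` followed
# by a blank line) before yielding it, so the consumer can split the stream on SSE event
# boundaries instead of guessing where one JSON payload ends and the next begins. This is the
# "proper SSE handling" referenced in the commit message.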


@registry(alias='pt')
class PyTorchRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
    self.is_encoder_decoder = llm.model.config.is_encoder_decoder
    if hasattr(llm.model, 'device'):
      self.device = llm.model.device
    else:
      self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    from ._generation import get_context_length, prepare_logits_processor

    if adapter_name is not None:
      self.model.set_adapter(adapter_name)
    max_new_tokens = attrs.pop('max_new_tokens', 256)
    context_length = attrs.pop('context_length', None)
    if context_length is None:
      context_length = get_context_length(self.model.config)
    if self.model.config.is_encoder_decoder:
      max_src_len = context_length
    else:
      max_src_len = context_length - max_new_tokens - 1
    prompt_token_ids = prompt_token_ids[-max_src_len:]
    stop_token_ids = [self.tokenizer.encode(it) for it in stop]
    if self.tokenizer.eos_token_id not in stop_token_ids:  # add eos token
      stop_token_ids.append(self.tokenizer.eos_token_id)
    config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs)
    logits_processor = prepare_logits_processor(config)
    cumulative_logprob = 0.0
    with torch.inference_mode():
      output_token_ids = list(prompt_token_ids)
      input_len = len(prompt_token_ids)
      if self.is_encoder_decoder:
        if config['logprobs']:  # FIXME: logprobs is not supported
          raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.')
        encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
        start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device)
      else:
        start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
      past_key_values = out = token = None
      finish_reason = None
      prompt_logprobs = []
      prompt_token_indices = []
      stopped = False
      sample_logprobs: SampleLogprobs = [None]  # The first token has no logprobs
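
      # Decode loop: step 0 (prefill) runs the full prompt (or the decoder start token for
      # encoder-decoder models) through the model once; every later step feeds only the token
      # sampled at the previous step together with ``past_key_values``, so attention states for
      # the prompt are reused from the KV cache instead of being recomputed.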
      for i in range(config['max_new_tokens']):
        if i == 0:  # prefill
          if self.is_encoder_decoder:
            out = self.model.decoder(input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True)
            logits = self.model.lm_head(out[0])
          else:
            out = self.model(input_ids=start_ids, use_cache=True)
            logits = out.logits
        elif self.is_encoder_decoder:  # decoding
          out = self.model.decoder(
            input_ids=torch.as_tensor([[token]], device=self.device),
            encoder_hidden_states=encoder_output,
            past_key_values=past_key_values,
            use_cache=True,
          )
          logits = self.model.lm_head(out[0])
        else:
          out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True)
          logits = out.logits
        past_key_values = out.past_key_values
        if logits_processor:
          if config['repetition_penalty'] > 1.0:
            tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
          else:
            tmp_output_ids = None
          last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
          last_token_logits = logits[0, -1, :]
        # Switch to CPU to avoid some bugs in the mps backend.
        if self.device.type == 'mps':
          last_token_logits = last_token_logits.float().to('cpu')
        # TODO: refactor for better sampling logic and apply penalties correctly
        # support sequence generation, best_of
        if config['temperature'] < 1e-5 or config['top_p'] < 1e-8:  # greedy
          _, indices = torch.topk(last_token_logits, 2)
        else:
          probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
          indices = torch.multinomial(probs, num_samples=2)
        tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_token_ids.append(token)
        if config['logprobs']:
          # NOTE: We can't use last_token_logits since logprobs is based on raw logits
          logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
          token_logprobs = logprobs[token].item()
          cumulative_logprob += token_logprobs
          if config['prompt_logprobs']:
            for token_id in prompt_token_ids:
              if token_id in prompt_token_indices:
                continue
              prompt_token_indices.append(token_id)
              prompt_logprobs.append({token_id: logprobs[token_id].item()})
        stopped = token in stop_token_ids
        tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
        # XXX: Move this to API server
        text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
        if len(stop) > 0:
          for it in stop:
            pos = text.rfind(it, rfind_start)
            if pos != -1:
              text, stopped = text[:pos], True
              break
        if config['logprobs']:
          sample_logprobs.append({token: token_logprobs})
        yield GenerationOutput(
          prompt='',
          finished=False,
          outputs=[
            CompletionChunk(
              index=0,
              text=text,
              token_ids=tmp_output_ids,
              cumulative_logprob=cumulative_logprob,
              logprobs=sample_logprobs if config['logprobs'] else None,
              finish_reason=None,
            )
          ],
          prompt_token_ids=prompt_token_ids,
          prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
          request_id=request_id,
        ).model_dump_json()
        if stopped:
          break
      else:
        finish_reason = 'length'
      if stopped:
        finish_reason = 'stop'
      yield GenerationOutput(
        prompt='',
        finished=True,
        outputs=[
          CompletionChunk(
            index=0,
            text=text,
            token_ids=output_token_ids,
            cumulative_logprob=cumulative_logprob,
            logprobs=sample_logprobs if config['logprobs'] else None,
            finish_reason=finish_reason,
          )
        ],
        prompt_token_ids=prompt_token_ids,
        prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
        request_id=request_id,
      ).model_dump_json()
    # Clean up references to the KV cache and free CUDA memory.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
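

# Illustrative sketch (not part of upstream): a new backend would follow the same pattern by
# registering another Runnable, e.g.
#
#   @registry(alias='mybackend')  # hypothetical backend name
#   class MyBackendRunnable(bentoml.Runnable):
#     SUPPORTED_RESOURCES = ('cpu',)
#     SUPPORTS_CPU_MULTI_THREADING = True
#
#     @bentoml.Runnable.method(batchable=False)
#     async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
#       yield GenerationOutput(...).model_dump_json()  # same streaming contract as above
#
# ``runner()`` would then pick it up whenever ``llm.__llm_backend__ == 'mybackend'``.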