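"""Runner factory and backend Runnable implementations (CTranslate2, vLLM, PyTorch) for OpenLLM.

Mirrored from https://github.com/bentoml/OpenLLM.git. ``runner(llm)`` wraps an ``openllm.LLM``
into a ``bentoml.Runner`` whose runnable class is selected from ``_registry`` by the LLM's backend.
"""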
from __future__ import annotations
import gc, traceback, types, typing as t
import torch, bentoml, openllm
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available

if t.TYPE_CHECKING:
  from openllm_core._typing_compat import M, T
  from ._runners import Runner

_registry = {}
__all__ = ['runner']


def registry(cls=None, *, alias=None):
  # Register a Runnable under its backend name: the class name minus the trailing
  # 'Runnable', lowercased, unless an explicit ``alias`` is given.
  def decorator(_cls):
    _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls
    return _cls

  if cls is None:
    return decorator
  return decorator(cls)
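# Note: with the registrations below, ``_registry`` ends up mapping backend names to Runnable
# classes, roughly {'ctranslate': CTranslateRunnable, 'vllm': vLLMRunnable, 'pt': PyTorchRunnable}.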


def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
  try:
    assert llm.bentomodel
  except (bentoml.exceptions.NotFound, AssertionError) as err:
    raise RuntimeError(f'Failed to locate {llm.bentomodel}: {err}') from err

  return types.new_class(
    llm.config.__class__.__name__[:-6] + 'Runner',
    (bentoml.Runner,), #
    exec_body=lambda ns: ns.update({
      'llm_type': llm.llm_type,
      'identifying_params': llm.identifying_params, #
      'llm_tag': llm.tag,
      'llm': llm,
      'config': llm.config,
      'backend': llm.__llm_backend__, #
      '__module__': llm.__module__,
      '__repr__': ReprMixin.__repr__, #
      '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}',
      '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}),
      '__repr_args__': lambda _: (
        (
          'runner_methods',
          {
            method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None}
            for method in _.runner_methods
          },
        ),
        ('config', llm.config.model_dump(flatten=True)),
        ('llm_type', llm.llm_type),
        ('backend', llm.__llm_backend__),
        ('llm_tag', llm.tag),
      ),
      'has_adapters': llm.has_adapters,
      'template': llm.config.template,
      'system_message': llm.config.system_message,
    }),
  )(
    _registry[llm.__llm_backend__],
    name=f"llm-{llm.config['start_name']}-runner",
    models=[llm.bentomodel],
    scheduling_strategy=openllm.CascadingResourceStrategy,
    runnable_init_params={'llm': llm},
  )
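# Minimal usage sketch (hypothetical; the exact ``openllm.LLM`` constructor arguments depend on
# the OpenLLM version in use): ``runner`` returns a ``bentoml.Runner`` whose runnable class is
# looked up in ``_registry`` by the LLM's backend.
#
#   llm = openllm.LLM('facebook/opt-125m')  # assumed model id; any supported model works
#   svc = bentoml.Service('llm-service', runners=[runner(llm)])
#
# The runner then exposes the backend's ``generate_iterator`` as a non-batchable streaming method.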


@registry
class CTranslateRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    if not is_ctranslate_available():
      raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
    cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
    tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
    async for request_output in self.model.async_generate_tokens(tokens, **sampling_params):
      if config['logprobs']:
        cumulative_logprob += request_output.log_prob
      output_token_ids.append(request_output.token_id)
      text = self.tokenizer.decode(
        output_token_ids[input_len:],
        skip_special_tokens=True, #
        spaces_between_special_tokens=False,
        clean_up_tokenization_spaces=True, #
      )
      yield GenerationOutput(
        prompt_token_ids=prompt_token_ids, #
        prompt='',
        finished=request_output.is_last,
        request_id=request_id, #
        outputs=[
          CompletionChunk(
            index=0,
            text=text,
            finish_reason=None, #
            token_ids=output_token_ids[input_len:],
            cumulative_logprob=cumulative_logprob, #
            # TODO: logprobs, but seems like we don't have access to the raw logits
          )
        ],
      ).model_dump_json()


@registry
class vLLMRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    if not is_vllm_available():
      raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
    import vllm

    self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
    num_gpus, dev = 1, openllm.utils.device_count()
    if dev >= 2:
      num_gpus = min(dev // 2 * 2, dev)  # use the largest even number of available GPUs for tensor parallelism
    quantise = llm.quantise if llm.quantise and llm.quantise in {'gptq', 'awq', 'squeezellm'} else None
    dtype = torch.float16 if quantise == 'gptq' else llm._torch_dtype  # NOTE: GPTQ quantisation doesn't support bfloat16 yet.
    try:
      self.model = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(
          worker_use_ray=False,
          engine_use_ray=False, #
          tokenizer_mode='auto',
          tensor_parallel_size=num_gpus, #
          model=llm.bentomodel.path,
          tokenizer=llm.bentomodel.path, #
          trust_remote_code=llm.trust_remote_code,
          dtype=dtype, #
          max_model_len=llm._max_model_len,
          gpu_memory_utilization=llm._gpu_memory_utilization, #
          quantization=quantise,
        )
      )
    except Exception as err:
      traceback.print_exc()
      raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    _, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
    async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
      # Stream each partial GenerationOutput as a marshalled server-sent event payload.
      out = GenerationOutput.from_vllm(request_output).model_dump_json()
      out = bentoml.io.SSE(out).marshal()
      yield out


@registry(alias='pt')
class PyTorchRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm):
    self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
    self.is_encoder_decoder = llm.model.config.is_encoder_decoder
    if hasattr(llm.model, 'device'):
      self.device = llm.model.device
    else:
      self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
    from ._generation import get_context_length, prepare_logits_processor

    if adapter_name is not None:
      self.model.set_adapter(adapter_name)

    max_new_tokens = attrs.pop('max_new_tokens', 256)
    context_length = attrs.pop('context_length', None)
    if context_length is None:
      context_length = get_context_length(self.model.config)
    if self.model.config.is_encoder_decoder:
      max_src_len = context_length
    else:
      max_src_len = context_length - max_new_tokens - 1
    prompt_token_ids = prompt_token_ids[-max_src_len:]

    stop_token_ids = [self.tokenizer.encode(it) for it in stop]
    if self.tokenizer.eos_token_id not in stop_token_ids:  # add eos token
      stop_token_ids.append(self.tokenizer.eos_token_id)

    config = self.config.model_construct_env(max_new_tokens=max_new_tokens, **attrs)
    logits_processor = prepare_logits_processor(config)
    cumulative_logprob = 0.0

    with torch.inference_mode():
      output_token_ids = list(prompt_token_ids)
      input_len = len(prompt_token_ids)

      if self.is_encoder_decoder:
        if config['logprobs']:  # FIXME: logprobs are not supported here
          raise NotImplementedError('Logprobs are not yet supported with encoder-decoder models.')
        encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
        start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device)
      else:
        start_ids = torch.as_tensor([prompt_token_ids], device=self.device)

      past_key_values = out = token = None
      finish_reason = None
      prompt_logprobs = []
      prompt_token_indices = []
      stopped = False
      sample_logprobs: SampleLogprobs = [None]  # The first token has no logprobs.

      # Token-by-token decoding loop with manual KV-cache reuse (step 0 runs the prefill).
      for i in range(config['max_new_tokens']):
        if i == 0:  # prefill
          if self.is_encoder_decoder:
            out = self.model.decoder(input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True)
            logits = self.model.lm_head(out[0])
          else:
            out = self.model(input_ids=start_ids, use_cache=True)
            logits = out.logits
        elif self.is_encoder_decoder:  # decoding
          out = self.model.decoder(
            input_ids=torch.as_tensor([[token]], device=self.device),
            encoder_hidden_states=encoder_output,
            past_key_values=past_key_values,
            use_cache=True,
          )
          logits = self.model.lm_head(out[0])
        else:
          out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True)
          logits = out.logits
        past_key_values = out.past_key_values
        if logits_processor:
          if config['repetition_penalty'] > 1.0:
            tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
          else:
            tmp_output_ids = None
          last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
          last_token_logits = logits[0, -1, :]

        # Move to CPU to avoid some bugs in the mps backend.
        if self.device.type == 'mps':
          last_token_logits = last_token_logits.float().to('cpu')

        # TODO: refactor for better sampling logic and apply penalties correctly;
        # support sequence generation, best_of
        if config['temperature'] < 1e-5 or config['top_p'] < 1e-8:  # greedy
          _, indices = torch.topk(last_token_logits, 2)
        else:
          probs = torch.softmax(last_token_logits, dim=-1, dtype=torch.float)
          indices = torch.multinomial(probs, num_samples=2)
        tokens = [int(token) for token in indices.tolist()]

        token = tokens[0]
        output_token_ids.append(token)
        if config['logprobs']:
          # NOTE: We can't use last_token_logits, since logprobs are based on the raw logits.
          logprobs = torch.log_softmax(logits[0, -1, :], dim=-1, dtype=torch.float)
          token_logprobs = logprobs[token].item()
          cumulative_logprob += token_logprobs

        if config['prompt_logprobs']:
          for token_id in prompt_token_ids:
            if token_id in prompt_token_indices:
              continue
            prompt_token_indices.append(token_id)
            prompt_logprobs.append({token_id: logprobs[token_id].item()})

        stopped = token in stop_token_ids

        tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
        # XXX: Move this to the API server
        text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)

        if len(stop) > 0:
          for it in stop:
            pos = text.rfind(it, rfind_start)
            if pos != -1:
              text, stopped = text[:pos], True
              break

        if config['logprobs']:
          sample_logprobs.append({token: token_logprobs})

        yield GenerationOutput(
          prompt='',
          finished=False,
          outputs=[
            CompletionChunk(
              index=0,
              text=text,
              token_ids=tmp_output_ids,
              cumulative_logprob=cumulative_logprob,
              logprobs=sample_logprobs if config['logprobs'] else None,
              finish_reason=None,
            )
          ],
          prompt_token_ids=prompt_token_ids,
          prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
          request_id=request_id,
        ).model_dump_json()
        if stopped:
          break
      else:
        finish_reason = 'length'
      if stopped:
        finish_reason = 'stop'
      yield GenerationOutput(
        prompt='',
        finished=True,
        outputs=[
          CompletionChunk(
            index=0,
            text=text,
            token_ids=output_token_ids,
            cumulative_logprob=cumulative_logprob,
            logprobs=sample_logprobs if config['logprobs'] else None,
            finish_reason=finish_reason,
          )
        ],
        prompt_token_ids=prompt_token_ids,
        prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
        request_id=request_id,
      ).model_dump_json()

    # Clean up cached state between requests.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
|