from __future__ import annotations

import gc
import os
import traceback
import typing as t

import torch

import bentoml
import openllm

from openllm.exceptions import OpenLLMException
from openllm_core._schemas import CompletionChunk
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import LiteralBackend
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import is_vllm_available

if t.TYPE_CHECKING:
  import vllm

  from openllm_core._schemas import FinishReason
else:
  vllm = openllm.utils.LazyLoader('vllm', globals(), 'vllm')

_DEFAULT_TOKENIZER = 'hf-internal-testing/llama-tokenizer'

__all__ = ['runnable']

def runnable(backend: LiteralBackend | None = None) -> type[bentoml.Runnable]:
  # Resolve the backend in order of precedence: explicit argument, the
  # OPENLLM_BACKEND environment variable, then 'vllm' when vLLM is installed,
  # otherwise the PyTorch backend ('pt').
  backend = t.cast(
    LiteralBackend,
    first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'),
  )
  return vLLMRunnable if backend == 'vllm' else PyTorchRunnable
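
# Illustrative sketch, not part of the upstream file: the factory above only selects a
# Runnable class. Wiring it into a BentoML runner is assumed to look roughly like the
# snippet below, where `llm` is an already constructed openllm.LLM instance and the
# runner name is arbitrary.
#
#   runnable_cls = runnable()  # honours OPENLLM_BACKEND, prefers vLLM when available
#   runner = bentoml.Runner(
#     runnable_cls,
#     name='llm-generate-runner',
#     runnable_init_params={'llm': llm},
#   )
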


class vLLMRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm: openllm.LLM[M, T]) -> None:
    self.config = llm.config
    # With two or more visible GPUs, use the largest even number of them for tensor parallelism.
    num_gpus, dev = 1, openllm.utils.device_count()
    if dev >= 2:
      num_gpus = min(dev // 2 * 2, dev)
    # Only forward quantisation schemes that vLLM accepts here.
    quantization = None
    if llm._quantise and llm._quantise in {'awq', 'squeezellm'}:
      quantization = llm._quantise
    try:
      self.model = vllm.AsyncLLMEngine.from_engine_args(
        vllm.AsyncEngineArgs(
          model=llm.bentomodel.path,
          tokenizer=llm.bentomodel.path,
          trust_remote_code=llm.trust_remote_code,
          tokenizer_mode='auto',
          tensor_parallel_size=num_gpus,
          dtype='auto',
          quantization=quantization,
          disable_log_requests=not get_debug_mode(),
          worker_use_ray=False,
          engine_use_ray=False,
        )
      )
    except Exception as err:
      traceback.print_exc()
      raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(
    self,
    prompt_token_ids: list[int],
    request_id: str,
    stop: str | t.Iterable[str] | None = None,
    adapter_name: str | None = None,
    **attrs: t.Any,
  ) -> t.AsyncGenerator[str, None]:
    if adapter_name is not None:
      raise NotImplementedError('Adapter is not supported with vLLM.')
    # Normalise `stop` into a set of stop strings.
    stop_: set[str] = set()
    if isinstance(stop, str) and stop != '':
      stop_.add(stop)
    elif isinstance(stop, t.Iterable):
      stop_.update(stop)

    temperature = attrs.pop('temperature', self.config['temperature'])
    top_p = attrs.pop('top_p', self.config['top_p'])
    # A near-zero temperature means greedy decoding, so neutralise nucleus sampling.
    if temperature <= 1e-5:
      top_p = 1.0
    sampling_params = self.config.model_construct_env(
      stop=list(stop_), temperature=temperature, top_p=top_p, **attrs
    ).to_sampling_config()

    async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
      # XXX: need a hook to serialise None values correctly
      if request_output.prompt_logprobs is not None:
        request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
      yield GenerationOutput.from_vllm(request_output).model_dump_json()
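
  # Illustrative sketch, not part of the upstream file: each item yielded above is a
  # JSON-serialised GenerationOutput, so a consumer could decode it with the standard
  # library. `runnable_instance`, `prompt_token_ids` and `request_id` are assumed to be
  # supplied by the caller; in a Bento service the method is normally invoked through a
  # BentoML runner rather than called directly.
  #
  #   import json
  #
  #   async for chunk in runnable_instance.generate_iterator(prompt_token_ids, request_id):
  #     payload = json.loads(chunk)
  #     # 'outputs' is a list of completion chunks; 'finished' flips to True on the last one.
  #     print(payload['outputs'][0]['text'])
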


class PyTorchRunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
  SUPPORTS_CPU_MULTI_THREADING = True

  def __init__(self, llm: openllm.LLM[M, T]) -> None:
    self.model = llm.model
    self.tokenizer = llm.tokenizer
    self.config = llm.config
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  @bentoml.Runnable.method(batchable=False)
  async def generate_iterator(
    self,
    prompt_token_ids: list[int],
    request_id: str,
    stop: str | t.Iterable[str] | None = None,
    adapter_name: str | None = None,
    **attrs: t.Any,
  ) -> t.AsyncGenerator[str, None]:
    if adapter_name is not None:
      # Select the requested adapter on the underlying model before generating.
      self.model.set_adapter(adapter_name)
    # Delegate to `forward` and serialise each streamed GenerationOutput to JSON.
    async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
      yield generation_output.model_dump_json()
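
  # Illustrative sketch, not part of the upstream file: extra keyword arguments are
  # forwarded untouched into `self.config.model_construct_env(**attrs)` inside `forward`
  # below, so per-request sampling overrides can ride along with the prompt, e.g.
  #
  #   async for chunk in runnable_instance.generate_iterator(
  #     prompt_token_ids, request_id='req-1', stop=['\n\n'], temperature=0.2, max_new_tokens=64
  #   ):
  #     ...
  #
  # The exact set of accepted keys depends on the model's generation config.
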

  async def forward(
    self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any
  ) -> t.AsyncGenerator[GenerationOutput, None]:
    from ._generation import is_partial_stop
    from ._generation import prepare_logits_processor

    # Normalise `stop` into a set of stop strings.
    stop_: set[str] = set()
    if isinstance(stop, str) and stop != '':
      stop_.add(stop)
    elif isinstance(stop, t.Iterable):
      stop_.update(stop)
    config = self.config.model_construct_env(**attrs)

    with torch.inference_mode():
      # TODO: Support context_length check
      # context_length: int | None = attrs.pop('context_length', None)
      # if context_length is None: context_length = get_context_length(self.model.config)
      # max_src_len = context_length - config['max_new_tokens'] - 1
      # prompt_token_ids = prompt_token_ids[-max_src_len:]
      output_token_ids = list(prompt_token_ids)
      input_len = len(prompt_token_ids)

      logits_processor = prepare_logits_processor(config)

      past_key_values = out = token = None
      finish_reason: t.Optional[FinishReason] = None
      for i in range(config['max_new_tokens']):
        if i == 0:  # prefill: run the full prompt through the model and cache key/values
          out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
        else:  # decoding: feed only the newly sampled token plus the cached key/values
          out = self.model(
            torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values
          )
        logits = out.logits
        past_key_values = out.past_key_values

        if logits_processor:
          # The repetition penalty needs the ids generated so far; other processors do not.
          if config['repetition_penalty'] > 1.0:
            tmp_output_ids: t.Any = torch.as_tensor([output_token_ids], device=self.device)
          else:
            tmp_output_ids = None
          last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
          last_token_logits = logits[0, -1, :]

        # Move logits to CPU to work around bugs in the mps backend.
        if self.device.type == 'mps':
          last_token_logits = last_token_logits.float().to('cpu')

        if config['temperature'] < 1e-5 or config['top_p'] < 1e-8:  # greedy decoding
          _, indices = torch.topk(last_token_logits, 2)
          tokens = [int(index) for index in indices.tolist()]
        else:  # sample from the softmax distribution
          probs = torch.softmax(last_token_logits, dim=-1)
          indices = torch.multinomial(probs, num_samples=2)
          tokens = [int(token) for token in indices.tolist()]

        # Two candidates are kept, but only the first is appended to the output.
        token = tokens[0]
        output_token_ids.append(token)

        stopped = False

        tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
        # XXX: Move this to API server
        text = self.tokenizer.decode(
          tmp_output_ids,
          skip_special_tokens=True,
          spaces_between_special_tokens=False,
          clean_up_tokenization_spaces=True,
        )
        partially_stopped = False
        if stop_:
          for it in stop_:
            pos = text.rfind(it, rfind_start)
            if pos != -1:
              # A stop string was produced: truncate the text and finish after this chunk.
              text, stopped = text[:pos], True
              break
            else:
              partially_stopped = is_partial_stop(text, it)
              if partially_stopped:
                break
        # Withhold the chunk while the text might still grow into a stop string.
        if not partially_stopped:
          yield GenerationOutput(
            prompt='',
            finished=False,
            outputs=[
              CompletionChunk(
                index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
              )
            ],
            prompt_token_ids=prompt_token_ids,
            request_id=request_id,
          )
        if stopped:
          break
      else:
        # The loop exhausted max_new_tokens without hitting a stop string.
        finish_reason = 'length'
      if stopped:
        finish_reason = 'stop'
      yield GenerationOutput(
        prompt='',
        finished=True,
        outputs=[
          CompletionChunk(
            index=0,
            text=text,
            token_ids=output_token_ids[input_len:],
            cumulative_logprob=0.0,
            finish_reason=finish_reason,
          )
        ],
        prompt_token_ids=prompt_token_ids,
        request_id=request_id,
      )

    # Clean up the KV cache and free GPU memory.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
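
# Note, illustrative only: `is_partial_stop`, imported from ._generation above, decides
# whether the tail of the decoded text could still grow into one of the stop strings.
# Conceptually it is a check along these lines; the real helper may differ in details.
#
#   def _is_partial_stop_sketch(output: str, stop_str: str) -> bool:
#     # True when some non-empty proper prefix of `stop_str` matches the end of `output`,
#     # i.e. the next decoded tokens might still complete the stop string.
#     return any(output.endswith(stop_str[:i]) for i in range(1, len(stop_str)))
#
# While that holds, `forward` withholds the current chunk so a partially generated stop
# string never leaks to the client.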