feat(vllm): streaming (#260)

This commit is contained in:
Aaron Pham
2023-08-26 07:27:32 -04:00
committed by GitHub
parent 63a27c8c41
commit 938fd362bb
7 changed files with 108 additions and 43 deletions

View File

@@ -11,7 +11,6 @@ import re
import traceback
import types
import typing as t
import uuid
import attr
import fs.path
@@ -347,8 +346,8 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]
# assign vllm specific implementation
if cls.__llm_implementation__ == 'vllm':
globs.update({'_vllm_generate': vllm_generate, '_vllm_postprocess_generate': vllm_postprocess_generate})
lines.extend([_setattr_class(it, f'_vllm_{it}') for it in {'generate', 'postprocess_generate'}])
globs.update({'_vllm_generate': vllm_generate, '_vllm_postprocess_generate': vllm_postprocess_generate, '_vllm_generate_iterator': vllm_generate_iterator})
lines.extend([_setattr_class(it, f'_vllm_{it}') for it in {'generate', 'postprocess_generate', 'generate_iterator'}])
# cached attribute initialisation
interface_anns = codegen.get_annotations(LLMInterface)
@@ -364,10 +363,36 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]
return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations=anns)
def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str:
return generation_result[0]['outputs'][0]['text']
def vllm_generate_iterator(
self: LLM['vllm.LLMEngine', T], prompt: str, /, *, echo: bool = False, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any
) -> t.Iterator[dict[str, t.Any]]:
request_id: str = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
if stop_token_ids is None: stop_token_ids = []
stop_token_ids.append(self.tokenizer.eos_token_id)
stop_ = set()
if isinstance(stop, str) and stop != '': stop_.add(stop)
elif isinstance(stop, list) and stop != []: stop_.update(stop)
for tid in stop_token_ids:
if tid: stop_.add(self.tokenizer.decode(tid))
if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=config.to_sampling_config())
while self.model.has_unfinished_requests():
for request_output in self.model.step():
prompt = request_output.prompt
if echo: text_outputs = [prompt + output.text for output in request_output.outputs]
else: text_outputs = [output.text for output in request_output.outputs]
yield {'text': text_outputs, 'error_code': 0}
if request_output.finished: break
def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]:
request_id: str = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
outputs: list[vllm.RequestOutput] = []
# TODO: support prompt_token_ids
self.model.add_request(request_id=str(uuid.uuid4().hex), prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config())
self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config())
while self.model.has_unfinished_requests():
outputs.extend([r for r in self.model.step() if r.finished])
return [unmarshal_vllm_outputs(i) for i in outputs]
@@ -1018,14 +1043,13 @@ class LLM(LLMInterface[M, T], ReprMixin):
/,
*,
context_length: int | None = None,
echo: bool = True,
echo: bool = False,
stream_interval: int = 2,
stop: str | t.Iterable[str] | None = None,
stop_token_ids: list[int] | None = None,
**attrs: t.Any
) -> t.Iterator[t.Any]:
# NOTE: encoder-decoder models will need to implement their own generate_iterator for now
# inspired from fastchat's generate_stream_func
from ._generation import get_context_length, is_partial_stop, prepare_logits_processor
len_prompt = len(prompt)
@@ -1045,29 +1069,44 @@ class LLM(LLMInterface[M, T], ReprMixin):
past_key_values = out = token = None
for i in range(self.config['max_new_tokens']):
torch.cuda.synchronize()
if i == 0: # prefill
out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
out = self.model(torch.as_tensor([input_ids]), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else: # decoding
out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
out = self.model(input_ids=torch.as_tensor([[token]]), use_cache=True, past_key_values=past_key_values)
logits = out.logits
past_key_values = out.past_key_values
if logits_processor:
if self.config['repetition_penalty'] > 1.0: tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=logits.device)
else: tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
last_token_logits = logits[0, -1, :]
last_token_logits = logits_processor(torch.as_tensor([output_ids], device=logits.device)
if self.config['repetition_penalty'] > 1.0 else None, logits[:, -1, :])[0] if logits_processor else logits[0, -1, :]
# Switch to CPU by avoiding some bugs in mps backend.
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')
if self.config['temperature'] < 1e-5 or self.config['top_p'] < 1e-8: token = int(torch.argmax(last_token_logits)) # greedy
else: token = int(torch.multinomial(torch.softmax(last_token_logits, dim=-1), num_samples=1))
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
torch.cuda.synchronize()
if token in stop_token_ids: stopped = True
else: stopped = False
# Yield the output tokens
if i % stream_interval == 0 or i == self.config['max_new_tokens'] - 1 or stopped:
tmp_output_ids = output_ids if echo else output_ids[input_echo_len:]
rfind_start = len_prompt if echo else 0
if echo:
tmp_output_ids = output_ids
rfind_start = len_prompt
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
partially_stopped = False
@@ -1097,7 +1136,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
elif stopped: finish_reason = 'stop'
else: finish_reason = None
yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': finish_reason}
# Clean
del past_key_values, out
gc.collect()
@@ -1160,6 +1198,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
class _Runnable(bentoml.Runnable):
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
SUPPORTS_CPU_MULTI_THREADING = True
framework = self.__llm_implementation__
def __init__(__self: _Runnable):
# NOTE: The side effect of this line
@@ -1190,6 +1229,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
if __self.framework =='vllm': attrs.setdefault('request_id', openllm_core.utils.gen_random_uuid())
return self.generate(prompt, **attrs)
@bentoml.Runnable.method(**method_signature(generate_sig))
@@ -1203,14 +1243,15 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
pre = 0
for outputs in self.generate_iterator(prompt, **attrs):
output_text = outputs['text'].strip().split(' ')
for outputs in self.generate_iterator(prompt, request_id=openllm_core.utils.gen_random_uuid(), **attrs):
output_text = outputs['text'][0] if __self.framework == 'vllm' else outputs['text']
output_text = output_text.strip().split(' ')
now = len(output_text) - 1
if now > pre:
yield ' '.join(output_text[pre:now])
yield ' '.join(output_text[pre:now]) + ' '
pre = now
yield ' '.join(output_text[pre:])
return ' '.join(output_text)
yield ' '.join(output_text[pre:]) + ' '
return ' '.join(output_text) + ' '
return types.new_class(self.__class__.__name__ + 'Runnable', (_Runnable,), {}, lambda ns: ns.update({'SUPPORTED_RESOURCES': ('nvidia.com/gpu', 'amd.com/gpu') if self.config['requires_gpu'] else ('nvidia.com/gpu', 'amd.com/gpu', 'cpu'), '__module__': self.__module__, '__doc__': self.config['env'].start_docstring}))
def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
@@ -1256,11 +1297,12 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
yield 'llm_type', __self.llm_type
yield 'runtime', self.runtime
yield 'llm_tag', self.tag
yield 'llm_framework', self.__llm_implementation__
return types.new_class(
self.__class__.__name__ + 'Runner', (bentoml.Runner,),
exec_body=lambda ns: ns.update({
'llm_type': self.llm_type, 'identifying_params': self.identifying_params, 'llm_tag': self.tag, 'llm': self, 'config': self.config, 'implementation': self.__llm_implementation__, 'peft_adapters': property(fget=available_adapters), 'download_model': self.ensure_model_id_exists, '__call__': _wrapped_generate_run, 'embed': _wrapped_embeddings_run, '__module__': self.__module__, '__doc__': self.config['env'].start_docstring, '__repr__': ReprMixin.__repr__, '__repr_keys__': property(_wrapped_repr_keys), '__repr_args__': _wrapped_repr_args, 'supports_embeddings': self['supports_embeddings'], 'supports_hf_agent': self['supports_generate_one'], 'has_adapters': self._adapters_mapping is not None
'llm_type': self.llm_type, 'identifying_params': self.identifying_params, 'llm_framework': self.__llm_implementation__, 'llm_tag': self.tag, 'llm': self, 'config': self.config, 'implementation': self.__llm_implementation__, 'peft_adapters': property(fget=available_adapters), 'download_model': self.ensure_model_id_exists, '__call__': _wrapped_generate_run, 'embed': _wrapped_embeddings_run, '__module__': self.__module__, '__doc__': self.config['env'].start_docstring, '__repr__': ReprMixin.__repr__, '__repr_keys__': property(_wrapped_repr_keys), '__repr_args__': _wrapped_repr_args, 'supports_embeddings': self['supports_embeddings'], 'supports_hf_agent': self['supports_generate_one'], 'has_adapters': self._adapters_mapping is not None
})
)
__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class', 'LLMEmbeddings']

View File

@@ -44,10 +44,11 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
config = qa_inputs.llm_config.model_dump()
responses = await runner.generate.async_run(qa_inputs.prompt, **{'adapter_name': qa_inputs.adapter_name, **config})
return openllm.GenerationOutput(responses=responses, configuration=config)
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event_stream'))
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
echo = input_dict.pop('echo', False)
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **qa_inputs.llm_config.model_dump())
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
@svc.api(
route='/v1/metadata',
input=bentoml.io.Text(),
@@ -55,7 +56,7 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
'model_id': runner.llm.model_id,
'timeout': 3600,
'model_name': llm_config['model_name'],
'framework': 'pt',
'framework': runner.llm_framework,
'configuration': '',
'supports_embeddings': runner.supports_embeddings,
'supports_hf_agent': runner.supports_hf_agent
@@ -126,6 +127,7 @@ if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
hf_app = Starlette(debug=True, routes=[Route('/agent', hf_agent, methods=['POST'])])
svc.mount_asgi_app(hf_app, path='/hf')
# general metadata app
async def list_adapter_v1(_: Request) -> Response:
res: dict[str, t.Any] = {}
if runner.peft_adapters['success'] is True: res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}