fix(openai): correct stop tokens and finish_reason state (#722)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Authored by Aaron Pham on 2023-11-22 04:21:13 -05:00, committed by GitHub
parent 06626e7d1e
commit 63d86faa32
7 changed files with 103 additions and 64 deletions
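
The core of the fix is to normalize `stop` once at the API boundary and hand the runners a plain list. Below is a minimal standalone sketch of that normalization, not OpenLLM's actual helper (`normalize_stop` is an illustrative name), showing why a bare string misbehaves wherever a list of stop sequences is expected:

```python
# Minimal sketch (not OpenLLM's helper): normalize a `stop` value that may be
# None, a single string, or a list of strings into the list shape runners expect.
import typing as t

def normalize_stop(stop: t.Union[str, t.List[str], None]) -> t.Optional[t.List[str]]:
  if not stop:          # None or '' -> no stop sequences
    return None
  if isinstance(stop, str):
    return [stop]       # a single sequence becomes a one-element list
  return list(stop)     # already an iterable of sequences

# Why it matters: list('###') == ['#', '#', '#'], so a raw string passed where a
# list is expected silently turns one stop sequence into per-character stops.
assert normalize_stop('###') == ['###']
assert normalize_stop(['\n', '###']) == ['\n', '###']
assert normalize_stop(None) is None
```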

View File

@@ -118,7 +118,7 @@ class LLM(t.Generic[M, T], ReprMixin):
     previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
     try:
       generator = self.runner.generate_iterator.async_stream(
-        prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
+        prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)
       )
     except Exception as err:
       raise RuntimeError(f'Failed to start generation task: {err}') from err
@@ -127,8 +127,6 @@ class LLM(t.Generic[M, T], ReprMixin):
     async for out in generator:
       generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
       delta_outputs = [None] * len(generated.outputs)
-      if generated.finished:
-        break
       for output in generated.outputs:
         i = output.index
         delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
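
Dropping the early `break` on `generated.finished` matters because the final chunk of a stream is the one that carries the finish reason and the last text delta. A toy illustration (the `Chunk` class and `deltas` helper are made up for this example, not OpenLLM code):

```python
# Hypothetical illustration: breaking as soon as a stream is marked finished drops
# the last chunk, which is the one that carries finish_reason and the final delta.
from dataclasses import dataclass
import typing as t

@dataclass
class Chunk:
  text: str
  finished: bool = False
  finish_reason: t.Optional[str] = None

stream = [Chunk('Hello'), Chunk('Hello world'), Chunk('Hello world!', finished=True, finish_reason='stop')]

def deltas(chunks, break_on_finished):
  previous, out = '', []
  for chunk in chunks:
    if break_on_finished and chunk.finished:
      break  # old behaviour: the 'stop' finish_reason and trailing '!' never surface
    out.append((chunk.text[len(previous):], chunk.finish_reason))
    previous = chunk.text
  return out

print(deltas(stream, break_on_finished=True))   # [('Hello', None), (' world', None)]
print(deltas(stream, break_on_finished=False))  # [('Hello', None), (' world', None), ('!', 'stop')]
```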

View File

@@ -89,13 +89,7 @@ class CTranslateRunnable(bentoml.Runnable):
   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
     cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
     tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
@@ -165,18 +159,9 @@ class vLLMRunnable(bentoml.Runnable):
   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
     async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
       # XXX: Need to write a hook for serialisation None correctly
       if request_output.prompt_logprobs is not None:
         request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
       yield GenerationOutput.from_vllm(request_output).model_dump_json()
@@ -197,14 +182,7 @@ class PyTorchRunnable(bentoml.Runnable):
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
     if adapter_name is not None:
       self.model.set_adapter(adapter_name)
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop_), **attrs):
+    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
       yield generation_output.model_dump_json()
   async def forward(self, prompt_token_ids, request_id, stop, **attrs):
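
Each runner previously re-derived a `set` of stop strings; after this change they assume the caller has already normalized `stop` (via `_stop_converter` in the protocol module and `list(stop)` at the call site). A simplified before/after sketch of that contract, with illustrative function names rather than the runner code itself:

```python
# Sketch of the pattern being removed vs. the new contract (names are illustrative).
import typing as t

def old_per_runner_normalize(stop):
  # Every runner repeated this block; a set also drops ordering and duplicates.
  stop_ = set()
  if isinstance(stop, str) and stop != '':
    stop_.add(stop)
  elif isinstance(stop, t.Iterable):
    stop_.update(stop)
  return list(stop_)

def new_contract(stop: t.Optional[t.List[str]]):
  # After this commit the caller guarantees `stop` is None or a list of strings,
  # so runners can pass it straight through without re-normalizing.
  return list(stop) if stop else []

print(old_per_runner_normalize(['a', 'b', 'a']))  # order not guaranteed, duplicates gone
print(new_contract(['a', 'b', 'a']))              # ['a', 'b', 'a']
```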

View File

@@ -10,7 +10,6 @@ from starlette.applications import Starlette
 from starlette.responses import JSONResponse, StreamingResponse
 from starlette.routing import Route
-from openllm_core._schemas import SampleLogprobs
 from openllm_core.utils import converter, gen_random_uuid
 from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
@@ -165,34 +164,38 @@ async def chat_completions(req, llm):
     logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
-  def create_stream_response_json(index, text, finish_reason=None):
-    return jsonify_attr(
-      ChatCompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
-        ],
-      )
-    )
+  def create_stream_response_json(index, text, finish_reason=None, usage=None):
+    response = ChatCompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
+      ],
+    )
+    if usage is not None: response.usage = usage
+    return jsonify_attr(response)
-  async def completion_stream_generator():
+  async def chat_completion_stream_generator():
     # first chunk with role
     for i in range(config['n']):
       yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
+    previous_num_tokens = [0] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
+        previous_num_tokens[output.index] += len(output.token_ids)
         if output.finish_reason is not None:
-          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'
   try:
     # Streaming case
     if request.stream:
-      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
+      return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
     final_result = None
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
@@ -287,34 +290,38 @@ async def completions(req, llm):
   # TODO: support use_beam_search
   stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
-  def create_stream_response_json(index, text, logprobs=None, finish_reason=None):
-    return jsonify_attr(
-      CompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
-        ],
-      )
-    )
+  def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
+    response = CompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
+      ],
+    )
+    if usage: response.usage = usage
+    return jsonify_attr(response)
   async def completion_stream_generator():
     previous_num_tokens = [0] * config['n']
+    previous_texts = [''] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
         if request.logprobs is not None:
           logprobs = create_logprobs(
-            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm
+            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
           )
         else:
           logprobs = None
         previous_num_tokens[i] += len(output.token_ids)
+        previous_texts[i] += output.text
         yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
         if output.finish_reason is not None:
           logprobs = LogProbs() if request.logprobs is not None else None
-          yield f'data: {create_stream_response_json(index=i, text="", logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(i, "", logprobs, output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'
   try:
try:
@@ -344,7 +351,7 @@ async def completions(req, llm):
     for output in final_result.outputs:
       if request.logprobs is not None:
         logprobs = create_logprobs(
-          token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
+          token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
         )
       else:
         logprobs = None
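
With `usage` now attached to the same chunk that carries `finish_reason`, a streaming client can read token counts from the last data event before `[DONE]`. An illustrative, simplified consumer of the SSE lines (payloads below are abbreviated and omit fields such as `id` and `model`):

```python
# Illustrative only: a tiny SSE line parser showing where the new `usage` payload
# appears -- on the chunk that carries finish_reason, just before [DONE].
import json

sse_lines = [
  'data: {"choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}',
  'data: {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": null}]}',
  'data: {"choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "stop"}],'
  ' "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}}',
  'data: [DONE]',
]

text, usage = '', None
for line in sse_lines:
  payload = line.removeprefix('data: ')
  if payload == '[DONE]':
    break
  chunk = json.loads(payload)
  text += chunk['choices'][0]['delta'].get('content') or ''
  if chunk['choices'][0]['finish_reason'] is not None:
    usage = chunk.get('usage')  # only present on the final chunk after this fix

print(text)   # Hello
print(usage)  # {'prompt_tokens': 5, 'completion_tokens': 2, 'total_tokens': 7}
```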

View File

@@ -5,6 +5,7 @@ import typing as t
 import attr
 import openllm_core
+from openllm_core._schemas import FinishReason
 from openllm_core.utils import converter
@@ -16,6 +17,9 @@ class ErrorResponse:
   param: t.Optional[str] = None
   code: t.Optional[str] = None
+def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]:
+  if not data: return None
+  return [data] if isinstance(data, str) else data
 @attr.define
 class CompletionRequest:
@@ -29,7 +33,7 @@ class CompletionRequest:
   stream: t.Optional[bool] = attr.field(default=False)
   logprobs: t.Optional[int] = attr.field(default=None)
   echo: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   presence_penalty: t.Optional[float] = attr.field(default=0.0)
   frequency_penalty: t.Optional[float] = attr.field(default=0.0)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
@@ -49,7 +53,7 @@ class ChatCompletionRequest:
   top_p: t.Optional[float] = attr.field(default=None)
   n: t.Optional[int] = attr.field(default=None)
   stream: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   max_tokens: t.Optional[int] = attr.field(default=None)
   presence_penalty: t.Optional[float] = attr.field(default=None)
   frequency_penalty: t.Optional[float] = attr.field(default=None)
@@ -80,7 +84,7 @@ class CompletionResponseChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 @attr.define
@@ -88,7 +92,7 @@ class CompletionResponseStreamChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 @attr.define
@@ -98,6 +102,7 @@ class CompletionStreamResponse:
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 @attr.define
@@ -132,14 +137,14 @@ converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role,
 class ChatCompletionResponseStreamChoice:
   index: int
   delta: Delta
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 @attr.define
 class ChatCompletionResponseChoice:
   index: int
   message: ChatMessage
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 @attr.define
@@ -159,6 +164,7 @@ class ChatCompletionStreamResponse:
   object: str = 'chat.completion.chunk'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 @attr.define
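
The `converter=_stop_converter` fields above mean `stop` is coerced the moment a request object is built, so downstream code never sees a bare string. A standalone attrs sketch of the same pattern (`RequestSketch` is an illustrative stand-in, not the real request classes):

```python
# Minimal attrs sketch of the converter pattern used above.
import typing as t
import attr

def _stop_converter(data: t.Union[str, t.List[str], None]) -> t.Optional[t.List[str]]:
  if not data: return None
  return [data] if isinstance(data, str) else data

@attr.define
class RequestSketch:
  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)

print(RequestSketch(stop='###').stop)          # ['###']
print(RequestSketch(stop=['\n', '###']).stop)  # ['\n', '###']
print(RequestSketch().stop)                    # None
```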