Mirror of https://github.com/bentoml/OpenLLM.git
fix(openai): correct stop tokens and finish_reason state (#722)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
@@ -118,7 +118,7 @@ class LLM(t.Generic[M, T], ReprMixin):
     previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
     try:
       generator = self.runner.generate_iterator.async_stream(
-        prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
+        prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)
       )
     except Exception as err:
       raise RuntimeError(f'Failed to start generation task: {err}') from err
@@ -127,8 +127,6 @@ class LLM(t.Generic[M, T], ReprMixin):
     async for out in generator:
       generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
       delta_outputs = [None] * len(generated.outputs)
-      if generated.finished:
-        break
       for output in generated.outputs:
         i = output.index
         delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
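Note: the surrounding loop streams deltas by remembering, per choice index, how much text and how many tokens have already been emitted, then slicing only the new suffix out of each cumulative output. A minimal standalone sketch of that bookkeeping (plain tuples stand in for GenerationOutput; this is not the actual OpenLLM implementation):

    def iter_deltas(stream, n):
      # stream yields, per step, a list of (index, cumulative_text, cumulative_token_ids)
      previous_texts, previous_num_tokens = [''] * n, [0] * n
      for outputs in stream:
        for index, text, token_ids in outputs:
          delta_text = text[len(previous_texts[index]):]
          delta_tokens = token_ids[previous_num_tokens[index]:]
          previous_texts[index], previous_num_tokens[index] = text, len(token_ids)
          yield index, delta_text, delta_tokens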
@@ -89,13 +89,7 @@ class CTranslateRunnable(bentoml.Runnable):

   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
     cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
     tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
@@ -165,18 +159,9 @@ class vLLMRunnable(bentoml.Runnable):

   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)

     async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
       # XXX: Need to write a hook for serialisation None correctly
       if request_output.prompt_logprobs is not None:
         request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
       yield GenerationOutput.from_vllm(request_output).model_dump_json()
@@ -197,14 +182,7 @@ class PyTorchRunnable(bentoml.Runnable):
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
     if adapter_name is not None:
       self.model.set_adapter(adapter_name)

-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop_), **attrs):
+    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
       yield generation_output.model_dump_json()

   async def forward(self, prompt_token_ids, request_id, stop, **attrs):
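Note: all three runnables previously coerced `stop` into a set before calling model_construct_env; with this change that normalization happens once upstream (the request schema's _stop_converter plus list(stop) at the call sites), and each runner simply forwards what it receives. For reference, a standalone sketch of what the removed per-runner blocks did (hypothetical helper name, not OpenLLM API):

    import typing as t

    def normalize_stop(stop: t.Union[str, t.Iterable[str], None]) -> t.List[str]:
      # Non-empty string -> single-element list; any other iterable -> deduplicated list; None -> [].
      stop_: t.Set[str] = set()
      if isinstance(stop, str) and stop != '':
        stop_.add(stop)
      elif isinstance(stop, t.Iterable):
        stop_.update(stop)
      return list(stop_)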
@@ -10,7 +10,6 @@ from starlette.applications import Starlette
 from starlette.responses import JSONResponse, StreamingResponse
 from starlette.routing import Route

-from openllm_core._schemas import SampleLogprobs
 from openllm_core.utils import converter, gen_random_uuid

 from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
@@ -165,34 +164,38 @@ async def chat_completions(req, llm):
     logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

-  def create_stream_response_json(index, text, finish_reason=None):
-    return jsonify_attr(
-      ChatCompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
-        ],
-      )
-    )
+  def create_stream_response_json(index, text, finish_reason=None, usage=None):
+    response = ChatCompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
+      ],
+    )
+    if usage is not None: response.usage = usage
+    return jsonify_attr(response)

-  async def completion_stream_generator():
+  async def chat_completion_stream_generator():
     # first chunk with role
     for i in range(config['n']):
       yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"

+    previous_num_tokens = [0] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
+        previous_num_tokens[output.index] += len(output.token_ids)
         if output.finish_reason is not None:
-          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'

   try:
     # Streaming case
     if request.stream:
-      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
+      return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
     final_result = None
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
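Note: with this change each choice's final SSE chunk carries a usage block alongside finish_reason, instead of finish_reason alone. A minimal client-side sketch of reading such a stream (the endpoint URL, port, and model name are assumptions about a local deployment; httpx is used only as an example HTTP client):

    import json
    import httpx

    payload = {'model': 'my-model', 'stream': True, 'messages': [{'role': 'user', 'content': 'Hi'}]}
    with httpx.stream('POST', 'http://localhost:3000/v1/chat/completions', json=payload, timeout=None) as r:
      for line in r.iter_lines():
        if not line.startswith('data: '):
          continue
        data = line[len('data: '):]
        if data == '[DONE]':
          break
        chunk = json.loads(data)
        choice = chunk['choices'][0]
        if choice.get('finish_reason') is not None:
          # after this commit the final chunk also reports token accounting
          print(choice['finish_reason'], chunk.get('usage'))
        else:
          print(choice['delta'].get('content', ''), end='')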
@@ -287,34 +290,38 @@ async def completions(req, llm):
   # TODO: support use_beam_search
   stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])

-  def create_stream_response_json(index, text, logprobs=None, finish_reason=None):
-    return jsonify_attr(
-      CompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
-        ],
-      )
-    )
+  def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
+    response = CompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
+      ],
+    )
+    if usage: response.usage = usage
+    return jsonify_attr(response)

   async def completion_stream_generator():
     previous_num_tokens = [0] * config['n']
+    previous_texts = [''] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
         if request.logprobs is not None:
           logprobs = create_logprobs(
-            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm
+            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
           )
         else:
           logprobs = None
         previous_num_tokens[i] += len(output.token_ids)
+        previous_texts[i] += output.text
         yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
         if output.finish_reason is not None:
           logprobs = LogProbs() if request.logprobs is not None else None
-          yield f'data: {create_stream_response_json(index=i, text="", logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(i, "", logprobs, output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'

   try:
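Note: initial_text_offset exists so that streamed logprobs keep cumulative text_offset values: each chunk's offsets continue from the length of the text already emitted for that choice instead of restarting at zero. A small standalone sketch of that offset arithmetic (illustrative helper, not the actual create_logprobs):

    def chunk_text_offsets(token_texts, initial_text_offset=0):
      # Character offset of each token, continuing from text already streamed.
      offsets, pos = [], initial_text_offset
      for tok in token_texts:
        offsets.append(pos)
        pos += len(tok)
      return offsets

    chunk_text_offsets(['Hel', 'lo'])                           # [0, 3]
    chunk_text_offsets([' wor', 'ld'], initial_text_offset=5)   # [5, 9]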
@@ -344,7 +351,7 @@ async def completions(req, llm):
     for output in final_result.outputs:
       if request.logprobs is not None:
         logprobs = create_logprobs(
-          token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
+          token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
         )
       else:
         logprobs = None
@@ -5,6 +5,7 @@ import typing as t
 import attr

 import openllm_core
+from openllm_core._schemas import FinishReason
 from openllm_core.utils import converter

@@ -16,6 +17,9 @@ class ErrorResponse:
   param: t.Optional[str] = None
   code: t.Optional[str] = None

+def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]:
+  if not data: return None
+  return [data] if isinstance(data, str) else data

 @attr.define
 class CompletionRequest:
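Note: the converter lets clients send stop either as a single string or as a list, as the OpenAI API allows, and normalizes it before the value reaches the runners. A minimal sketch of the same pattern on a stripped-down attrs class (illustrative stand-in, not the real CompletionRequest):

    import typing as t
    import attr

    def _stop_converter(data: t.Union[str, t.List[str], None]) -> t.Optional[t.List[str]]:
      if not data:
        return None
      return [data] if isinstance(data, str) else data

    @attr.define
    class _Request:  # stand-in for CompletionRequest / ChatCompletionRequest
      stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)

    assert _Request(stop='\n').stop == ['\n']
    assert _Request(stop=['\n', '###']).stop == ['\n', '###']
    assert _Request().stop is None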
@@ -29,7 +33,7 @@ class CompletionRequest:
   stream: t.Optional[bool] = attr.field(default=False)
   logprobs: t.Optional[int] = attr.field(default=None)
   echo: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   presence_penalty: t.Optional[float] = attr.field(default=0.0)
   frequency_penalty: t.Optional[float] = attr.field(default=0.0)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
@@ -49,7 +53,7 @@ class ChatCompletionRequest:
   top_p: t.Optional[float] = attr.field(default=None)
   n: t.Optional[int] = attr.field(default=None)
   stream: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   max_tokens: t.Optional[int] = attr.field(default=None)
   presence_penalty: t.Optional[float] = attr.field(default=None)
   frequency_penalty: t.Optional[float] = attr.field(default=None)
@@ -80,7 +84,7 @@ class CompletionResponseChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None


 @attr.define
@@ -88,7 +92,7 @@ class CompletionResponseStreamChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None


 @attr.define
@@ -98,6 +102,7 @@ class CompletionStreamResponse:
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)


 @attr.define
@@ -132,14 +137,14 @@ converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role,
 class ChatCompletionResponseStreamChoice:
   index: int
   delta: Delta
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None


 @attr.define
 class ChatCompletionResponseChoice:
   index: int
   message: ChatMessage
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None


 @attr.define
@@ -159,6 +164,7 @@ class ChatCompletionStreamResponse:
   object: str = 'chat.completion.chunk'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)


 @attr.define