Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-19 07:06:02 -05:00)
feat(openai): supports echo (#760)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
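For context, `echo` mirrors the OpenAI Completions API parameter of the same name: when set, the response text includes the prompt itself. A minimal client-side sketch of what this change enables (the endpoint and model id are placeholders; any OpenAI-compatible client pointed at an OpenLLM server should behave the same):

```python
# Hypothetical usage sketch: endpoint and model id are placeholders.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
out = client.completions.create(
  model='facebook/opt-1.3b',          # placeholder model id
  prompt='The capital of France is',
  max_tokens=8,
  echo=True,                          # new: response text starts with the prompt
)
print(out.choices[0].text)            # e.g. 'The capital of France is Paris...'
```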
@@ -2,7 +2,6 @@ import functools
 import logging
 import time
 import traceback
 import typing as t
 from http import HTTPStatus
-
 import orjson
@@ -72,20 +71,25 @@ async def check_model(request, model):
   )


-def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
+def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
   # Create OpenAI-style logprobs.
   logprobs = LogProbs()
   last_token_len = 0
-  for token_id, id_logprob in zip(token_ids, id_logprobs):
+  if num_output_top_logprobs: logprobs.top_logprobs = []
+  for i, token_id in enumerate(token_ids):
+    step_top_logprobs = top_logprobs[i]
+    token_logprob = None
+    if step_top_logprobs is not None: token_logprob = step_top_logprobs[token_id]
     token = llm.tokenizer.convert_ids_to_tokens(token_id)
     logprobs.tokens.append(token)
-    logprobs.token_logprobs.append(id_logprob[token_id])
+    logprobs.token_logprobs.append(token_logprob)
     if len(logprobs.text_offset) == 0:
       logprobs.text_offset.append(initial_text_offset)
     else:
       logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
     last_token_len = len(token)
-    logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
+    if num_output_top_logprobs:
+      logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()} if step_top_logprobs else None)
   return logprobs
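To see what the reworked helper computes, here is a self-contained sketch of the same bookkeeping with toy stand-ins for `LogProbs` and the tokenizer (the real types live in OpenLLM; only the offset/top-k logic is taken from the diff):

```python
# Standalone sketch of the OpenAI-style logprobs bookkeeping above.
# `LogProbs` and the vocab table are toy stand-ins, not OpenLLM's real types.
from dataclasses import dataclass, field

@dataclass
class LogProbs:
  text_offset: list = field(default_factory=list)
  token_logprobs: list = field(default_factory=list)
  tokens: list = field(default_factory=list)
  top_logprobs: list = None

VOCAB = {1: 'Hello', 2: '▁world'}  # toy id -> token table

def convert_ids_to_tokens(i): return VOCAB[i]

def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0):
  logprobs, last_token_len = LogProbs(), 0
  if num_output_top_logprobs: logprobs.top_logprobs = []
  for i, token_id in enumerate(token_ids):
    step_top_logprobs = top_logprobs[i]
    token_logprob = step_top_logprobs[token_id] if step_top_logprobs is not None else None
    token = convert_ids_to_tokens(token_id)
    logprobs.tokens.append(token)
    logprobs.token_logprobs.append(token_logprob)
    # text_offset[i] is the character position where token i starts
    logprobs.text_offset.append(initial_text_offset if not logprobs.text_offset else logprobs.text_offset[-1] + last_token_len)
    last_token_len = len(token)
    if num_output_top_logprobs:
      logprobs.top_logprobs.append({convert_ids_to_tokens(j): p for j, p in step_top_logprobs.items()} if step_top_logprobs else None)
  return logprobs

out = create_logprobs([1, 2], [{1: -0.1}, {2: -0.7, 1: -1.2}], num_output_top_logprobs=2)
print(out.tokens)       # ['Hello', '▁world']
print(out.text_offset)  # [0, 5]
```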
@@ -140,12 +144,16 @@ async def chat_completions(req, llm):
   err_check = await check_model(request, llm.llm_type)
   if err_check is not None: return err_check

   if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")

   model_name, request_id = request.model, gen_random_uuid('chatcmpl')
+  created_time = int(time.monotonic())
   prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
   logger.debug('Prompt: %r', prompt)
   config = llm.config.compatible_options(request)
+
+  def get_role() -> str: return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'  # TODO: Support custom role here.
+
   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
   except Exception as err:
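The role used for both the echoed chunk and the reply now comes from the request instead of being hard-coded to 'assistant'. A small sketch of the rule, next to the matching Hugging Face call (`apply_chat_template` is the real transformers API; the model id is a placeholder):

```python
# Sketch of the role resolution added above. Model id is a placeholder;
# any tokenizer that ships a chat template behaves the same way.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
messages = [{'role': 'user', 'content': 'Write a haiku about autumn.'}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def get_role(messages, add_generation_prompt):
  # With a generation prompt we open a fresh assistant turn; without one,
  # the model continues the last message, so the echo keeps that role.
  return messages[-1]['role'] if not add_generation_prompt else 'assistant'

print(get_role(messages, True))   # 'assistant'
print(get_role(messages, False))  # 'user'
```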
@@ -162,28 +170,36 @@ async def chat_completions(req, llm):
     if usage is not None: response.usage = usage
     return jsonify_attr(response)

-  async def chat_completion_stream_generator():
+  async def completion_stream_generator():
     # first chunk with role
-    for i in range(config['n']):
-      yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
+    role = get_role()
+    for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
+
+    if request.echo:
+      last_message, last_content = request.messages[-1], ''
+      if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
+      if last_content:
+        for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'

     previous_num_tokens = [0] * config['n']
+    finish_reason_sent = [False] * config['n']
     async for res in result_generator:
       for output in res.outputs:
+        if finish_reason_sent[output.index]: continue
         yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
         previous_num_tokens[output.index] += len(output.token_ids)
         if output.finish_reason is not None:
           prompt_tokens = len(res.prompt_token_ids)
           usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
           yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
+          finish_reason_sent[output.index] = True
     yield 'data: [DONE]\n\n'

   try:
     # Streaming case
-    if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
+    if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
-    final_result = None
-    texts, token_ids = [[]] * config['n'], [[]] * config['n']
+    final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
       if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
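On the wire, the streaming path now opens every choice with a role-only delta, optionally follows it with a delta carrying the echoed message content, then streams generated deltas until a finish_reason (sent exactly once, thanks to `finish_reason_sent`). A client-side sketch (placeholders as before; `echo` is not a standard chat parameter, so it is passed via `extra_body`):

```python
# Hypothetical client sketch: consuming the echoing chat stream.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
stream = client.chat.completions.create(
  model='my-model',  # placeholder
  messages=[{'role': 'user', 'content': 'Tell me a joke.'}],
  stream=True,
  extra_body={'echo': True},  # non-standard field, forwarded to the server
)
for chunk in stream:
  delta = chunk.choices[0].delta
  # chunk 1: delta.role only; chunk 2 (echo): the last message's content;
  # afterwards: generated text, then a final chunk with finish_reason set.
  if delta.content: print(delta.content, end='', flush=True)
```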
@@ -197,26 +213,27 @@ async def chat_completions(req, llm):
         for output in final_result.outputs
       ]
     )

+    role = get_role()
     choices = [
       ChatCompletionResponseChoice(
         index=output.index,
-        message=ChatMessage(role='assistant', content=output.text),
+        message=ChatMessage(role=role, content=output.text),
         finish_reason=output.finish_reason,
       )
       for output in final_result.outputs
     ]
+    if request.echo:
+      last_message, last_content = request.messages[-1], ''
+      if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
+      for choice in choices:
+        full_message = last_content + choice.message.content
+        choice.message.content = full_message

     num_prompt_tokens = len(final_result.prompt_token_ids)
     num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
     usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
     response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)

     if request.stream:
       # When user requests streaming but we don't stream, we still need to
       # return a streaming response with a single event.
       async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
         yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
       return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)

     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
     traceback.print_exc(); logger.error('Error generating completion: %s', err)
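Note the condition guarding the non-streaming echo: the request's last message is only prepended when it has content and its role matches the reply role, i.e. when the client asked the model to finish a partial message. A self-contained sketch of just that rule:

```python
# Sketch of the echo condition above (pure function, no server state).
def echoed_prefix(messages, role):
  last = messages[-1]
  return last['content'] if last.get('content') and last.get('role') == role else ''

msgs = [{'role': 'user', 'content': 'Count:'}, {'role': 'assistant', 'content': 'one, two,'}]
# Continuing a partial assistant turn (add_generation_prompt=False -> role 'assistant'):
assert echoed_prefix(msgs, 'assistant') == 'one, two,'   # prepended to the reply
# Fresh assistant turn: the last message is the user's, so nothing is echoed:
assert echoed_prefix(msgs[:1], 'assistant') == ''
```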
@@ -237,7 +254,10 @@ async def completions(req, llm):
   err_check = await check_model(request, llm.llm_type)
   if err_check is not None: return err_check

-  if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
+  # OpenAI API supports echoing the prompt when max_tokens is 0.
+  echo_without_generation = request.echo and request.max_tokens == 0
+  if echo_without_generation: request.max_tokens = 1  # XXX: Hack to make sure we get the prompt back.

   if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
   if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
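The `max_tokens == 0` case follows OpenAI semantics: `echo=True` with no token budget should return the prompt alone. Since the engine insists on generating at least one token, the handler bumps the budget to 1 and discards the generated text later. The flag logic in isolation:

```python
# Sketch of the request normalization above.
def normalize(echo: bool, max_tokens: int):
  echo_without_generation = echo and max_tokens == 0
  if echo_without_generation:
    max_tokens = 1  # engine must emit >= 1 token; its text is dropped downstream
  return echo_without_generation, max_tokens

assert normalize(True, 0) == (True, 1)
assert normalize(True, 16) == (False, 16)
assert normalize(False, 0) == (False, 0)
```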
@@ -272,15 +292,30 @@ async def completions(req, llm):
   async def completion_stream_generator():
     previous_num_tokens = [0] * config['n']
     previous_texts = [''] * config['n']
+    previous_echo = [False] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
+        delta_text = output.text
+        token_ids = output.token_ids
+        top_logprobs = output.logprobs[previous_num_tokens[i]:]
         logprobs = None
+
+        if request.echo and not previous_echo[i]:
+          if not echo_without_generation:
+            delta_text = res.prompt + delta_text
+            token_ids = res.prompt_token_ids + token_ids
+            top_logprobs = res.prompt_logprobs + top_logprobs
+          else:
+            delta_text = res.prompt
+            token_ids = res.prompt_token_ids
+            top_logprobs = res.prompt_logprobs
+          previous_echo[i] = True
         if request.logprobs is not None:
-          logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
+          logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], request.logprobs, len(previous_texts[i]), llm=llm)
         previous_num_tokens[i] += len(output.token_ids)
         previous_texts[i] += output.text
-        yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
+        yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
         if output.finish_reason is not None:
           logprobs = LogProbs() if request.logprobs is not None else None
           prompt_tokens = len(res.prompt_token_ids)
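The streaming echo hinges on `previous_echo`: per choice, the prompt (its text, token ids, and prompt logprobs) is stitched onto the first delta only, and every later delta passes through untouched. The pattern in isolation:

```python
# Sketch of the prepend-once pattern above, with fake (choice_index, text) deltas.
def stream_with_echo(prompt, deltas, n=1, echo=True, echo_without_generation=False):
  previous_echo = [False] * n
  for i, delta_text in deltas:
    if echo and not previous_echo[i]:
      delta_text = prompt if echo_without_generation else prompt + delta_text
      previous_echo[i] = True  # subsequent deltas for choice i are unchanged
    yield i, delta_text

print(list(stream_with_echo('Once', [(0, ' upon'), (0, ' a'), (0, ' time')])))
# [(0, 'Once upon'), (0, ' a'), (0, ' time')]
```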
@@ -292,8 +327,7 @@ async def completions(req, llm):
     # Streaming case
     if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
-    final_result = None
-    texts, token_ids = [[]] * config['n'], [[]] * config['n']
+    final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
       if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
@@ -309,24 +343,30 @@ async def completions(req, llm):
     )

     choices = []
+    prompt_token_ids = final_result.prompt_token_ids
+    prompt_logprobs = final_result.prompt_logprobs
+    prompt_text = final_result.prompt
     for output in final_result.outputs:
       logprobs = None
       if request.logprobs is not None:
-        logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
-      choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
+        if not echo_without_generation:
+          token_ids, top_logprobs = output.token_ids, output.logprobs
+          if request.echo: token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
+        else:
+          token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
+        logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
+      if not echo_without_generation:
+        output_text = output.text
+        if request.echo: output_text = prompt_text + output_text
+      else:
+        output_text = prompt_text
+      choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason)
       choices.append(choice_data)

     num_prompt_tokens = len(final_result.prompt_token_ids)
     num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
     usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
     response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)

     if request.stream:
       # When user requests streaming but we don't stream, we still need to
       # return a streaming response with a single event.
       async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
         yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
       return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
     traceback.print_exc()
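In the non-streaming path the final text of each choice therefore takes one of three shapes, depending on the two flags:

```python
# Sketch of the final-text composition above.
def compose(prompt_text, generated_text, echo, echo_without_generation):
  if echo_without_generation: return prompt_text          # prompt only (max_tokens == 0)
  if echo: return prompt_text + generated_text            # prompt + generation
  return generated_text                                   # generation only (default)

assert compose('Hi, ', 'there', True, False) == 'Hi, there'
assert compose('Hi, ', 'x', True, True) == 'Hi, '
assert compose('Hi, ', 'there', False, False) == 'there'
```

The matching type-stub declaration is updated below as well.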
@@ -18,7 +18,8 @@ async def check_model(
   request: Union[CompletionRequest, ChatCompletionRequest], model: str
 ) -> Optional[JSONResponse]: ...
 def create_logprobs(
-  token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T]
+  token_ids: List[int], top_logprobs: List[Dict[int, float]], #
+  num_output_top_logprobs: Optional[int] = ..., initial_text_offset: int = ..., *, llm: LLM[M, T]
 ) -> LogProbs: ...
 def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
 async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
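The extra stub parameter `num_output_top_logprobs` is fed from the request's `logprobs` field, i.e. how many alternatives to report per position. Requesting it from a client looks like this (placeholders as in the earlier sketches):

```python
# Hypothetical client sketch: logprobs together with echo.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
out = client.completions.create(
  model='my-model',            # placeholder
  prompt='Roses are',
  max_tokens=4,
  logprobs=3,                  # report the top 3 alternatives per token
  echo=True,                   # with this patch, coverage extends to prompt tokens
)
lp = out.choices[0].logprobs
print(lp.tokens, lp.token_logprobs, lp.top_logprobs, sep='\n')
```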