feat(openai): supports echo (#760)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-12-10 13:19:40 -05:00
committed by GitHub
parent bb4ed8b53c
commit c3a0b5c39f
2 changed files with 76 additions and 35 deletions

View File

@@ -2,7 +2,6 @@ import functools
import logging
import time
import traceback
import typing as t
from http import HTTPStatus
import orjson
@@ -72,20 +71,25 @@ async def check_model(request, model):
)
def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
# Create OpenAI-style logprobs.
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
if num_output_top_logprobs: logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
token_logprob = None
if step_top_logprobs is not None: token_logprob = step_top_logprobs[token_id]
token = llm.tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
if num_output_top_logprobs:
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()} if step_top_logprobs else None)
return logprobs
@@ -140,12 +144,16 @@ async def chat_completions(req, llm):
err_check = await check_model(request, llm.llm_type)
if err_check is not None: return err_check
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
created_time = int(time.monotonic())
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
logger.debug('Prompt: %r', prompt)
config = llm.config.compatible_options(request)
def get_role() -> str: return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here.
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
except Exception as err:
@@ -162,28 +170,36 @@ async def chat_completions(req, llm):
if usage is not None: response.usage = usage
return jsonify_attr(response)
async def chat_completion_stream_generator():
async def completion_stream_generator():
# first chunk with role
for i in range(config['n']):
yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
role = get_role()
for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
if request.echo:
last_message, last_content = request.messages[-1], ''
if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
if last_content:
for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
previous_num_tokens = [0] * config['n']
finish_reason_sent = [False] * config['n']
async for res in result_generator:
for output in res.outputs:
if finish_reason_sent[output.index]: continue
yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
previous_num_tokens[output.index] += len(output.token_ids)
if output.finish_reason is not None:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
finish_reason_sent[output.index] = True
yield 'data: [DONE]\n\n'
try:
# Streaming case
if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
@@ -197,26 +213,27 @@ async def chat_completions(req, llm):
for output in final_result.outputs
]
)
role = get_role()
choices = [
ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role='assistant', content=output.text),
message=ChatMessage(role=role, content=output.text),
finish_reason=output.finish_reason,
)
for output in final_result.outputs
]
if request.echo:
last_message, last_content = request.messages[-1], ''
if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
for choice in choices:
full_message = last_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc(); logger.error('Error generating completion: %s', err)
@@ -237,7 +254,10 @@ async def completions(req, llm):
err_check = await check_model(request, llm.llm_type)
if err_check is not None: return err_check
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
if echo_without_generation: request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back.
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
@@ -272,15 +292,30 @@ async def completions(req, llm):
async def completion_stream_generator():
previous_num_tokens = [0] * config['n']
previous_texts = [''] * config['n']
previous_echo = [False] * config['n']
async for res in result_generator:
for output in res.outputs:
i = output.index
delta_text = output.text
token_ids = output.token_ids
top_logprobs = output.logprobs[previous_num_tokens[i]:]
logprobs = None
if request.echo and not previous_echo[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
top_logprobs = res.prompt_logprobs + top_logprobs
else:
delta_text = res.prompt
token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
previous_echo[i] = True
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], request.logprobs, len(previous_texts[i]), llm=llm)
previous_num_tokens[i] += len(output.token_ids)
previous_texts[i] += output.text
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
@@ -292,8 +327,7 @@ async def completions(req, llm):
# Streaming case
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
@@ -309,24 +343,30 @@ async def completions(req, llm):
)
choices = []
prompt_token_ids = final_result.prompt_token_ids
prompt_logprobs = final_result.prompt_logprobs
prompt_text = final_result.prompt
for output in final_result.outputs:
logprobs = None
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
if not echo_without_generation:
token_ids, top_logprobs = output.token_ids, output.logprobs
if request.echo: token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
else:
token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
if not echo_without_generation:
output_text = output.text
if request.echo: output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason)
choices.append(choice_data)
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()

View File

@@ -18,7 +18,8 @@ async def check_model(
request: Union[CompletionRequest, ChatCompletionRequest], model: str
) -> Optional[JSONResponse]: ...
def create_logprobs(
token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T]
token_ids: List[int], top_logprobs: List[Dict[int, float]], #
num_output_top_logprobs: Optional[int] = ..., initial_text_offset: int = ..., *, llm: LLM[M, T]
) -> LogProbs: ...
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...