From c3a0b5c39f834e23d237ee968702b0bcb8fec395 Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Sun, 10 Dec 2023 13:19:40 -0500
Subject: [PATCH] feat(openai): supports echo (#760)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 .../src/openllm/entrypoints/openai.py  | 108 ++++++++++++------
 .../src/openllm/entrypoints/openai.pyi |   3 +-
 2 files changed, 76 insertions(+), 35 deletions(-)

diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py
index 225b7268..ae4894f2 100644
--- a/openllm-python/src/openllm/entrypoints/openai.py
+++ b/openllm-python/src/openllm/entrypoints/openai.py
@@ -2,7 +2,6 @@ import functools
 import logging
 import time
 import traceback
-import typing as t
 from http import HTTPStatus

 import orjson
@@ -72,20 +71,25 @@ async def check_model(request, model):
   )


-def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
+def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
   # Create OpenAI-style logprobs.
   logprobs = LogProbs()
   last_token_len = 0
-  for token_id, id_logprob in zip(token_ids, id_logprobs):
+  if num_output_top_logprobs: logprobs.top_logprobs = []
+  for i, token_id in enumerate(token_ids):
+    step_top_logprobs = top_logprobs[i]
+    token_logprob = None
+    if step_top_logprobs is not None: token_logprob = step_top_logprobs[token_id]
     token = llm.tokenizer.convert_ids_to_tokens(token_id)
     logprobs.tokens.append(token)
-    logprobs.token_logprobs.append(id_logprob[token_id])
+    logprobs.token_logprobs.append(token_logprob)
     if len(logprobs.text_offset) == 0:
       logprobs.text_offset.append(initial_text_offset)
     else:
       logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
     last_token_len = len(token)
-    logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
+    if num_output_top_logprobs:
+      logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()} if step_top_logprobs else None)
   return logprobs


@@ -140,12 +144,16 @@ async def chat_completions(req, llm):
   err_check = await check_model(request, llm.llm_type)
   if err_check is not None: return err_check
+
   if request.logit_bias is not None and len(request.logit_bias) > 0:
     return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
+
   model_name, request_id = request.model, gen_random_uuid('chatcmpl')
   created_time = int(time.monotonic())
   prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
   logger.debug('Prompt: %r', prompt)
   config = llm.config.compatible_options(request)
+
+  def get_role() -> str: return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'  # TODO: Support custom role here.
+
   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
   except Exception as err:
@@ -162,28 +170,36 @@ async def chat_completions(req, llm):
     if usage is not None: response.usage = usage
     return jsonify_attr(response)

-  async def chat_completion_stream_generator():
+  async def completion_stream_generator():
     # first chunk with role
-    for i in range(config['n']):
-      yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
+    role = get_role()
+    for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
+
+    if request.echo:
+      last_message, last_content = request.messages[-1], ''
+      if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
+      if last_content:
+        for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
     previous_num_tokens = [0] * config['n']
+    finish_reason_sent = [False] * config['n']
     async for res in result_generator:
       for output in res.outputs:
+        if finish_reason_sent[output.index]: continue
         yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
         previous_num_tokens[output.index] += len(output.token_ids)
         if output.finish_reason is not None:
           prompt_tokens = len(res.prompt_token_ids)
           usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
           yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
+          finish_reason_sent[output.index] = True
     yield 'data: [DONE]\n\n'

   try:
     # Streaming case
-    if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
+    if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
-    final_result = None
-    texts, token_ids = [[]] * config['n'], [[]] * config['n']
+    final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
       if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
@@ -197,26 +213,27 @@ async def chat_completions(req, llm):
         for output in final_result.outputs
       ]
     )
+
+    role = get_role()
     choices = [
       ChatCompletionResponseChoice(
         index=output.index,
-        message=ChatMessage(role='assistant', content=output.text),
+        message=ChatMessage(role=role, content=output.text),
         finish_reason=output.finish_reason,
       )
       for output in final_result.outputs
     ]
+    if request.echo:
+      last_message, last_content = request.messages[-1], ''
+      if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content']
+      for choice in choices:
+        full_message = last_content + choice.message.content
+        choice.message.content = full_message
+
     num_prompt_tokens = len(final_result.prompt_token_ids)
     num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
     usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
     response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
-
-    if request.stream:
-      # When user requests streaming but we don't stream, we still need to
-      # return a streaming response with a single event.
-      async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
-        yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
-      return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
-
     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
     traceback.print_exc(); logger.error('Error generating completion: %s', err)
@@ -237,7 +254,10 @@ async def completions(req, llm):
   err_check = await check_model(request, llm.llm_type)
   if err_check is not None: return err_check

-  if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
+  # OpenAI API supports echoing the prompt when max_tokens is 0.
+  echo_without_generation = request.echo and request.max_tokens == 0
+  if echo_without_generation: request.max_tokens = 1  # XXX: Hack to make sure we get the prompt back.
+
   if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
   if request.logit_bias is not None and len(request.logit_bias) > 0:
     return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
@@ -272,15 +292,30 @@ async def completions(req, llm):
   async def completion_stream_generator():
     previous_num_tokens = [0] * config['n']
     previous_texts = [''] * config['n']
+    previous_echo = [False] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
+        delta_text = output.text
+        token_ids = output.token_ids
+        top_logprobs = output.logprobs[previous_num_tokens[i]:]
         logprobs = None
+
+        if request.echo and not previous_echo[i]:
+          if not echo_without_generation:
+            delta_text = res.prompt + delta_text
+            token_ids = res.prompt_token_ids + token_ids
+            top_logprobs = res.prompt_logprobs + top_logprobs
+          else:
+            delta_text = res.prompt
+            token_ids = res.prompt_token_ids
+            top_logprobs = res.prompt_logprobs
+          previous_echo[i] = True
         if request.logprobs is not None:
-          logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
+          logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], request.logprobs, len(previous_texts[i]), llm=llm)
         previous_num_tokens[i] += len(output.token_ids)
         previous_texts[i] += output.text
-        yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
+        yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
         if output.finish_reason is not None:
           logprobs = LogProbs() if request.logprobs is not None else None
           prompt_tokens = len(res.prompt_token_ids)
@@ -292,8 +327,7 @@ async def completions(req, llm):
     # Streaming case
     if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
-    final_result = None
-    texts, token_ids = [[]] * config['n'], [[]] * config['n']
+    final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n']
     async for res in result_generator:
       if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
       for output in res.outputs:
@@ -309,24 +343,30 @@ async def completions(req, llm):
       )
     choices = []
+    prompt_token_ids = final_result.prompt_token_ids
+    prompt_logprobs = final_result.prompt_logprobs
+    prompt_text = final_result.prompt
     for output in final_result.outputs:
       logprobs = None
       if request.logprobs is not None:
-        logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
-      choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
+        if not echo_without_generation:
+          token_ids, top_logprobs = output.token_ids, output.logprobs
+          if request.echo: token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
+        else:
+          token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
+        logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
+      if not echo_without_generation:
+        output_text = output.text
+        if request.echo: output_text = prompt_text + output_text
+      else:
+        output_text = prompt_text
+      choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason)
       choices.append(choice_data)
     num_prompt_tokens = len(final_result.prompt_token_ids)
     num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
     usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
     response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
-
-    if request.stream:
-      # When user requests streaming but we don't stream, we still need to
-      # return a streaming response with a single event.
-      async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
-        yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
-      return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
-
     return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
   except Exception as err:
     traceback.print_exc()
diff --git a/openllm-python/src/openllm/entrypoints/openai.pyi b/openllm-python/src/openllm/entrypoints/openai.pyi
index 0935c227..110606bf 100644
--- a/openllm-python/src/openllm/entrypoints/openai.pyi
+++ b/openllm-python/src/openllm/entrypoints/openai.pyi
@@ -18,7 +18,8 @@ async def check_model(
   request: Union[CompletionRequest, ChatCompletionRequest], model: str
 ) -> Optional[JSONResponse]: ...
 def create_logprobs(
-  token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T]
+  token_ids: List[int], top_logprobs: List[Dict[int, float]],  #
+  num_output_top_logprobs: Optional[int] = ..., initial_text_offset: int = ..., *, llm: LLM[M, T]
 ) -> LogProbs: ...
 def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
 async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
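
A quick end-to-end check of the new echo behavior (not part of the patch itself): a minimal sketch, assuming an OpenLLM server is already running locally with its OpenAI-compatible API under /v1 on port 3000 and the openai>=1.0 Python client is installed; the base URL, API key, prompt, and token counts below are illustrative placeholders.

# Hedged usage sketch: exercises echo=True (and logprobs) against a locally
# running OpenLLM server's OpenAI-compatible completions endpoint.
# Assumptions: server reachable at http://localhost:3000/v1, openai>=1.0 installed.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
model_id = client.models.list().data[0].id  # whichever model this server exposes

completion = client.completions.create(
  model=model_id,
  prompt='import numpy as np',
  max_tokens=16,
  echo=True,   # the returned text should now start with the prompt itself
  logprobs=1,  # also exercises create_logprobs() with the prompt logprobs prepended
)
print(completion.choices[0].text)
# With echo=True and max_tokens=0, only the prompt should be echoed back
# (the echo_without_generation path added in this patch).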