feat(openai): chat templates and complete control of prompt generation (#725)

* feat(openai): chat templates and complete control of prompt generation

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: correctly use base chat templates

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: remove symlink

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-22 06:49:14 -05:00
committed by GitHub
parent 7aa0918a6f
commit b28b5269b5
11 changed files with 146 additions and 316 deletions

View File

@@ -56,24 +56,16 @@ schemas = get_generator(
logger = logging.getLogger(__name__)
def jsonify_attr(obj):
return orjson.dumps(converter.unstructure(obj)).decode()
def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode()
def error_response(status_code, message):
return JSONResponse(
{
'error': converter.unstructure(
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
)
},
{'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))},
status_code=status_code.value,
)
async def check_model(request, model):
if request.model == model:
return None
if request.model == model: return None
return error_response(
HTTPStatus.NOT_FOUND,
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
@@ -93,7 +85,6 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
return logprobs
@@ -106,7 +97,9 @@ def mount_to_svc(svc, llm):
debug=True,
routes=[
Route(
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
'/models',
functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm),
methods=['GET']
),
Route(
'/completions',
@@ -115,7 +108,11 @@ def mount_to_svc(svc, llm):
),
Route(
'/chat/completions',
functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
functools.partial(apply_schema(chat_completions,
__model_id__=llm.llm_type,
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False)), llm=llm),
methods=['POST'],
),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
@@ -127,11 +124,7 @@ def mount_to_svc(svc, llm):
# GET /v1/models
@add_schema_definitions
def list_models(_, llm):
return JSONResponse(
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
)
def list_models(_, llm): return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
# POST /v1/chat/completions
@add_schema_definitions
@@ -141,27 +134,22 @@ async def chat_completions(req, llm):
try:
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
except orjson.JSONDecodeError as err:
logger.debug('Sent body: %s', json_str)
logger.error('Invalid JSON input received: %s', err)
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received chat completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None:
return err_check
if err_check is not None: return err_check
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
created_time = int(time.monotonic())
prompt = llm.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
logger.debug('Prompt: %r', prompt)
config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
traceback.print_exc(); logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def create_stream_response_json(index, text, finish_reason=None, usage=None):
@@ -169,9 +157,7 @@ async def chat_completions(req, llm):
id=request_id,
created=created_time,
model=model_name,
choices=[
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
],
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)],
)
if usage is not None: response.usage = usage
return jsonify_attr(response)
@@ -194,20 +180,17 @@ async def chat_completions(req, llm):
try:
# Streaming case
if request.stream:
return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(
outputs=[
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
@@ -225,25 +208,18 @@ async def chat_completions(req, llm):
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = ChatCompletionResponse(
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
)
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
if request.stream: # type: ignore[unreachable]
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
async def fake_stream_generator() -> t.AsyncGenerator[str, None]: # type: ignore[unreachable]
yield f'data: {jsonify_attr(response)}\n\n'
yield 'data: [DONE]\n\n'
return StreamingResponse(
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
)
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
traceback.print_exc(); logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
@@ -255,23 +231,17 @@ async def completions(req, llm):
try:
request = converter.structure(orjson.loads(json_str), CompletionRequest)
except orjson.JSONDecodeError as err:
logger.debug('Sent body: %s', json_str)
logger.error('Invalid JSON input received: %s', err)
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received legacy completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None:
return err_check
if err_check is not None: return err_check
if request.echo:
return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
if request.suffix is not None:
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
if not request.prompt:
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
prompt = request.prompt
# TODO: Support multiple prompts
@@ -282,8 +252,7 @@ async def completions(req, llm):
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
traceback.print_exc(); logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
# best_of != n then we don't stream
@@ -295,9 +264,7 @@ async def completions(req, llm):
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
],
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
)
if usage: response.usage = usage
return jsonify_attr(response)
@@ -308,12 +275,9 @@ async def completions(req, llm):
async for res in result_generator:
for output in res.outputs:
i = output.index
logprobs = None
if request.logprobs is not None:
logprobs = create_logprobs(
token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
)
else:
logprobs = None
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
previous_num_tokens[i] += len(output.token_ids)
previous_texts[i] += output.text
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
@@ -326,20 +290,17 @@ async def completions(req, llm):
try:
# Streaming case
if stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(
outputs=[
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
@@ -349,15 +310,10 @@ async def completions(req, llm):
choices = []
for output in final_result.outputs:
logprobs = None
if request.logprobs is not None:
logprobs = create_logprobs(
token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
)
logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
choices.append(choice_data)
num_prompt_tokens = len(final_result.prompt_token_ids)
@@ -369,13 +325,8 @@ async def completions(req, llm):
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
yield f'data: {jsonify_attr(response)}\n\n'
yield 'data: [DONE]\n\n'
return StreamingResponse(
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
)
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()