mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
feat(openai): chat templates and complete control of prompt generation (#725)
* feat(openai): chat templates and complete control of prompt generation

  Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: correctly use base chat templates

  Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: remove symlink

  Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -56,24 +56,16 @@ schemas = get_generator(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def jsonify_attr(obj):
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
|
||||
)
|
||||
},
|
||||
{'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))},
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
async def check_model(request, model):
|
||||
if request.model == model:
|
||||
return None
|
||||
if request.model == model: return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
@@ -93,7 +85,6 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
|
||||
return logprobs
|
||||
|
||||
@@ -106,7 +97,9 @@ def mount_to_svc(svc, llm):
|
||||
debug=True,
|
||||
routes=[
|
||||
Route(
|
||||
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
|
||||
'/models',
|
||||
functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm),
|
||||
methods=['GET']
|
||||
),
|
||||
Route(
|
||||
'/completions',
|
||||
@@ -115,7 +108,11 @@ def mount_to_svc(svc, llm):
|
||||
),
|
||||
Route(
|
||||
'/chat/completions',
|
||||
functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
|
||||
functools.partial(apply_schema(chat_completions,
|
||||
__model_id__=llm.llm_type,
|
||||
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
|
||||
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
|
||||
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False)), llm=llm),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
@@ -127,11 +124,7 @@ def mount_to_svc(svc, llm):
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions
|
||||
def list_models(_, llm):
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
def list_models(_, llm): return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions
|
||||
@@ -141,27 +134,22 @@ async def chat_completions(req, llm):
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
if err_check is not None: return err_check
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
prompt = llm.tokenizer.apply_chat_template(
|
||||
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index, text, finish_reason=None, usage=None):
|
||||
@@ -169,9 +157,7 @@ async def chat_completions(req, llm):
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
|
||||
],
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)],
|
||||
)
|
||||
if usage is not None: response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
@@ -194,20 +180,17 @@ async def chat_completions(req, llm):
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream:
|
||||
return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
|
||||
if request.stream: return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
@@ -225,25 +208,18 @@ async def chat_completions(req, llm):
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream: # type: ignore[unreachable]
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
async def fake_stream_generator() -> t.AsyncGenerator[str, None]: # type: ignore[unreachable]
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
@@ -255,23 +231,17 @@ async def completions(req, llm):
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CompletionRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
if err_check is not None: return err_check
|
||||
|
||||
if request.echo:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
@@ -282,8 +252,7 @@ async def completions(req, llm):
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
traceback.print_exc(); logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
# best_of != n then we don't stream
|
||||
@@ -295,9 +264,7 @@ async def completions(req, llm):
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
|
||||
],
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
|
||||
)
|
||||
if usage: response.usage = usage
|
||||
return jsonify_attr(response)
|
||||
@@ -308,12 +275,9 @@ async def completions(req, llm):
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
logprobs = None
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], len(previous_texts[i]), llm=llm)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
previous_texts[i] += output.text
|
||||
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
|
||||
@@ -326,20 +290,17 @@ async def completions(req, llm):
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
@@ -349,15 +310,10 @@ async def completions(req, llm):
|
||||
|
||||
choices = []
|
||||
for output in final_result.outputs:
|
||||
logprobs = None
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs, llm=llm)
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
@@ -369,13 +325,8 @@ async def completions(req, llm):
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
yield f'data: {jsonify_attr(response)}\n\n'; yield 'data: [DONE]\n\n'
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
|
||||
Reference in New Issue
Block a user