mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
* ci: pre-commit autoupdate [pre-commit.ci] updates: - [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2) - [github.com/pre-commit/mirrors-eslint: v9.0.0-beta.0 → v9.0.0-beta.2](https://github.com/pre-commit/mirrors-eslint/compare/v9.0.0-beta.0...v9.0.0-beta.2) * ci: auto fixes from pre-commit.ci For more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
445 lines
17 KiB
Python
445 lines
17 KiB
Python
import functools
|
|
import logging
|
|
import time
|
|
import traceback
|
|
from http import HTTPStatus
|
|
|
|
import orjson
|
|
from starlette.applications import Starlette
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|
from starlette.routing import Route
|
|
|
|
from openllm_core.utils import converter, gen_random_uuid
|
|
|
|
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
|
|
from ..protocol.openai import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseChoice,
|
|
ChatCompletionResponseStreamChoice,
|
|
ChatCompletionStreamResponse,
|
|
ChatMessage,
|
|
CompletionRequest,
|
|
CompletionResponse,
|
|
CompletionResponseChoice,
|
|
CompletionResponseStreamChoice,
|
|
CompletionStreamResponse,
|
|
Delta,
|
|
ErrorResponse,
|
|
LogProbs,
|
|
ModelCard,
|
|
ModelList,
|
|
UsageInfo,
|
|
)
|
|
|
|
# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)

# Schema generator for the OpenAI-compatible routes. The listed request/response
# models are registered as components, and all routes share a single 'OpenAI' tag.
schemas = get_generator(
    'openai',
    components=[
        ErrorResponse, ModelList,
        ChatCompletionResponse, ChatCompletionRequest, ChatCompletionStreamResponse,
        CompletionRequest, CompletionResponse, CompletionStreamResponse,
    ],
    tags=[
        {
            'name': 'OpenAI',
            'description': 'OpenAI Compatible API support',
            'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
        }
    ],
)
|
|
|
|
|
|
def jsonify_attr(obj):
    """Serialize ``obj`` to a JSON string.

    The object is first passed through ``converter.unstructure``; the result is
    encoded with ``orjson.dumps`` and the encoded bytes are decoded to ``str``.
    """
    plain = converter.unstructure(obj)
    encoded = orjson.dumps(plain)
    return encoded.decode()
|
|
|
|
|
|
def error_response(status_code, message):
    """Build a JSON error response in the ``{'error': ...}`` envelope.

    Args:
        status_code: Object exposing ``.value`` (callers in this module pass
            ``HTTPStatus`` members). The value is used as the HTTP status and,
            stringified, as the error ``code``.
        message: Human-readable description placed in the error body.

    Returns:
        A ``JSONResponse`` whose body is ``{'error': <unstructured ErrorResponse>}``.
    """
    http_code = status_code.value
    detail = ErrorResponse(message=message, type='invalid_request_error', code=str(http_code))
    body = {'error': converter.unstructure(detail)}
    return JSONResponse(body, status_code=http_code)
|
|
|
|
|
|
async def check_model(request, model):
    """Verify that the request targets the model being served.

    Args:
        request: Parsed request object; only its ``model`` attribute is read.
        model: Identifier of the model this service exposes.

    Returns:
        ``None`` when ``request.model`` equals ``model``; otherwise a 404 error
        response pointing the caller at ``GET /v1/models``.
    """
    if request.model != model:
        return error_response(
            HTTPStatus.NOT_FOUND,
            f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
        )
    return None
|
|
|
|
|
|
def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
    """Assemble an OpenAI-style ``LogProbs`` record for a token sequence.

    Args:
        token_ids: Token ids to describe, in order.
        top_logprobs: Sequence indexed in parallel with ``token_ids``. Each entry is
            either ``None`` or a mapping from token id to log-probability.
        num_output_top_logprobs: When truthy, the per-step mappings are also
            reported (keyed by token string) in ``top_logprobs`` of the result.
        initial_text_offset: Offset recorded for the first token.
        llm: Object whose ``tokenizer.convert_ids_to_tokens`` maps ids to token strings.

    Returns:
        The populated ``LogProbs`` instance.
    """
    result = LogProbs()
    include_top = bool(num_output_top_logprobs)
    if include_top:
        result.top_logprobs = []

    previous_token_len = 0
    for position, token_id in enumerate(token_ids):
        step = top_logprobs[position]
        # A step without logprob data is reported as ``None`` for that token.
        chosen_logprob = step[token_id] if step is not None else None
        token = llm.tokenizer.convert_ids_to_tokens(token_id)
        result.tokens.append(token)
        result.token_logprobs.append(chosen_logprob)

        # Offsets accumulate the lengths of the token strings emitted so far.
        if result.text_offset:
            next_offset = result.text_offset[-1] + previous_token_len
        else:
            next_offset = initial_text_offset
        result.text_offset.append(next_offset)
        previous_token_len = len(token)

        if include_top:
            if step:
                entry = {llm.tokenizer.convert_ids_to_tokens(tid): logprob for tid, logprob in step.items()}
            else:
                entry = None
            result.top_logprobs.append(entry)
    return result
|
|
|
|
|
|
def mount_to_svc(svc, llm):
    """Mount the OpenAI-compatible endpoints on ``svc`` under ``/v1``.

    The ``__model_id__`` placeholder in each endpoint's ``__doc__`` is replaced with
    ``llm.llm_type``. The endpoints are then bound to ``llm``, wrapped into a
    Starlette sub-application and mounted on the service.

    Args:
        svc: Service object exposing ``mount_asgi_app``.
        llm: Model wrapper; ``llm_type`` and ``config`` are read here and the object
            itself is bound into every endpoint.

    Returns:
        Whatever ``append_schemas`` returns for ``svc`` and the generated schema.
    """
    model_id = llm.llm_type
    mount_path = '/v1'

    # Same order as the endpoints are routed below: models, completions, chat.
    for endpoint in (list_models, completions, chat_completions):
        endpoint.__doc__ = endpoint.__doc__.replace('__model_id__', model_id)

    models_route = Route(
        '/models', functools.partial(apply_schema(list_models, __model_id__=model_id), llm=llm), methods=['GET']
    )
    completions_route = Route(
        '/completions', functools.partial(apply_schema(completions, __model_id__=model_id), llm=llm), methods=['POST']
    )
    chat_endpoint = functools.partial(
        apply_schema(
            chat_completions,
            __model_id__=model_id,
            __chat_template__=orjson.dumps(llm.config.chat_template).decode(),
            __chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
            # 'True' when the config carries chat messages, 'False' otherwise.
            __add_generation_prompt__=str(llm.config.chat_messages is not None),
        ),
        llm=llm,
    )
    chat_route = Route('/chat/completions', chat_endpoint, methods=['POST'])
    # The schema route itself is excluded from the generated schema.
    schema_route = Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False)

    app = Starlette(debug=True, routes=[models_route, completions_route, chat_route, schema_route])
    svc.mount_asgi_app(app, path=mount_path)
    generated = schemas.get_schema(routes=app.routes, mount_path=mount_path)
    return append_schemas(svc, generated)
|
|
|
|
|
|
# GET /v1/models
@add_schema_definitions
def list_models(_, llm):
    # Report the single served model (``llm.llm_type``) as a ``ModelList`` payload.
    # The incoming request object is not used.
    # NOTE: documented with comments on purpose -- ``mount_to_svc`` reads and rewrites
    # this function's ``__doc__`` (presumably populated by ``add_schema_definitions``).
    card = ModelCard(id=llm.llm_type)
    listing = ModelList(data=[card])
    return JSONResponse(converter.unstructure(listing), status_code=HTTPStatus.OK.value)
|
|
|
|
|
|
# POST /v1/chat/completions
@add_schema_definitions
async def chat_completions(req, llm):
    # Handle an OpenAI-style chat completion request.
    #
    # Parameters:
    #   req: incoming HTTP request; its raw body is parsed as a ``ChatCompletionRequest``.
    #   llm: model wrapper providing ``tokenizer``, ``config`` and ``generate_iterator``.
    # Returns:
    #   * a ``StreamingResponse`` of server-sent events when ``request.stream`` is set,
    #   * otherwise a ``JSONResponse`` carrying a ``ChatCompletionResponse``,
    #   * or an ``error_response`` for invalid JSON, an unknown model, unsupported
    #     options, a disconnected client, or a generation failure.
    #
    # NOTE: documented with comments on purpose -- ``mount_to_svc`` reads and rewrites
    # this function's ``__doc__`` (presumably populated by ``add_schema_definitions``).

    # TODO: Check for length based on model context_length
    json_str = await req.body()
    try:
        request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
    except orjson.JSONDecodeError as err:
        logger.debug('Sent body: %s', json_str)
        logger.error('Invalid JSON input received: %s', err)
        return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
    logger.debug('Received chat completion request: %s', request)
    err_check = await check_model(request, llm.llm_type)
    if err_check is not None:
        return err_check

    if request.logit_bias is not None and len(request.logit_bias) > 0:
        return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")

    model_name, request_id = request.model, gen_random_uuid('chatcmpl')
    # ``created`` is reported to clients as a timestamp, so it must come from the
    # wall clock; ``time.monotonic()`` has an undefined reference point.
    created_time = int(time.time())
    prompt = llm.tokenizer.apply_chat_template(
        request.messages,
        tokenize=False,
        # The literal string 'None' is treated as "no custom template".
        chat_template=request.chat_template if request.chat_template != 'None' else None,
        add_generation_prompt=request.add_generation_prompt,
    )
    logger.debug('Prompt: %r', prompt)
    config = llm.config.compatible_options(request)

    def get_role() -> str:
        # Role attributed to the generated message.
        # TODO: Support custom role here.
        return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'

    try:
        result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
    except Exception as err:
        traceback.print_exc()
        logger.error('Error generating completion: %s', err)
        return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

    def create_stream_response_json(index, text, finish_reason=None, usage=None):
        # Serialize one content chunk for choice ``index``.
        response = ChatCompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[
                ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
            ],
        )
        if usage is not None:
            response.usage = usage
        return jsonify_attr(response)

    def create_stream_delta_json(index, delta):
        # Serialize a chunk carrying an arbitrary ``Delta`` (role or echoed content).
        return jsonify_attr(
            ChatCompletionStreamResponse(
                id=request_id,
                created=created_time,
                choices=[ChatCompletionResponseStreamChoice(index=index, delta=delta, finish_reason=None)],
                model=model_name,
            )
        )

    async def completion_stream_generator():
        # first chunk with role
        role = get_role()
        for i in range(config['n']):
            yield f'data: {create_stream_delta_json(i, Delta(role=role))}\n\n'

        if request.echo:
            # Echo the last message only when it has content and carries the reply role.
            last_message, last_content = request.messages[-1], ''
            if last_message.get('content') and last_message.get('role') == role:
                last_content = last_message['content']
            if last_content:
                for i in range(config['n']):
                    yield f'data: {create_stream_delta_json(i, Delta(content=last_content))}\n\n'

        previous_num_tokens = [0] * config['n']
        finish_reason_sent = [False] * config['n']
        async for res in result_generator:
            for output in res.outputs:
                idx = output.index
                if finish_reason_sent[idx]:
                    continue
                yield f'data: {create_stream_response_json(idx, output.text)}\n\n'
                previous_num_tokens[idx] += len(output.token_ids)
                if output.finish_reason is not None:
                    prompt_tokens = len(res.prompt_token_ids)
                    # Usage must be computed for *this* choice; the previous code indexed
                    # with a stale loop variable left over from the role-chunk loop.
                    completion_tokens = previous_num_tokens[idx]
                    usage = UsageInfo(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)
                    final_chunk = create_stream_response_json(idx, '', output.finish_reason, usage)
                    yield f'data: {final_chunk}\n\n'
                    finish_reason_sent[idx] = True
        yield 'data: [DONE]\n\n'

    try:
        # Streaming case
        if request.stream:
            return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
        # Non-streaming case
        final_result = None
        # One independent accumulator per choice. ``[[]] * n`` would alias a single
        # list n times and merge the output of every choice together.
        texts = [[] for _ in range(config['n'])]
        token_ids = [[] for _ in range(config['n'])]
        async for res in result_generator:
            if await req.is_disconnected():
                return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
            for output in res.outputs:
                texts[output.index].append(output.text)
                token_ids[output.index].extend(output.token_ids)
            final_result = res
        if final_result is None:
            return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
        # Replace the last chunk's outputs with the text/tokens accumulated over the run.
        final_result = final_result.with_options(
            outputs=[
                output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
                for output in final_result.outputs
            ]
        )

        role = get_role()
        choices = [
            ChatCompletionResponseChoice(
                index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason
            )
            for output in final_result.outputs
        ]
        if request.echo:
            last_message, last_content = request.messages[-1], ''
            if last_message.get('content') and last_message.get('role') == role:
                last_content = last_message['content']
            for choice in choices:
                choice.message.content = last_content + choice.message.content

        num_prompt_tokens = len(final_result.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
        usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
        response = ChatCompletionResponse(
            id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
        )
        return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
    except Exception as err:
        traceback.print_exc()
        logger.error('Error generating completion: %s', err)
        return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
|
|
|
|
|
# POST /v1/completions
@add_schema_definitions
async def completions(req, llm):
    # Handle an OpenAI-style (legacy) text completion request.
    #
    # Parameters:
    #   req: incoming HTTP request; its raw body is parsed as a ``CompletionRequest``.
    #   llm: model wrapper providing ``tokenizer``, ``config`` and ``generate_iterator``.
    # Returns:
    #   * a ``StreamingResponse`` of server-sent events when streaming is requested and
    #     ``best_of`` is unset or equal to ``n``,
    #   * otherwise a ``JSONResponse`` carrying a ``CompletionResponse``,
    #   * or an ``error_response`` for invalid JSON, an unknown model, unsupported
    #     options, a missing prompt, a disconnected client, or a generation failure.
    #
    # NOTE: documented with comments on purpose -- ``mount_to_svc`` reads and rewrites
    # this function's ``__doc__`` (presumably populated by ``add_schema_definitions``).

    # TODO: Check for length based on model context_length
    json_str = await req.body()
    try:
        request = converter.structure(orjson.loads(json_str), CompletionRequest)
    except orjson.JSONDecodeError as err:
        logger.debug('Sent body: %s', json_str)
        logger.error('Invalid JSON input received: %s', err)
        return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
    logger.debug('Received legacy completion request: %s', request)
    err_check = await check_model(request, llm.llm_type)
    if err_check is not None:
        return err_check

    # OpenAI API supports echoing the prompt when max_tokens is 0.
    echo_without_generation = request.echo and request.max_tokens == 0
    if echo_without_generation:
        request.max_tokens = 1  # XXX: Hack to make sure we get the prompt back.

    if request.suffix is not None:
        return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
    if request.logit_bias is not None and len(request.logit_bias) > 0:
        return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")

    if not request.prompt:
        return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
    prompt = request.prompt
    # TODO: Support multiple prompts

    model_name, request_id = request.model, gen_random_uuid('cmpl')
    # ``created`` is reported to clients as a timestamp, so it must come from the
    # wall clock; ``time.monotonic()`` has an undefined reference point.
    created_time = int(time.time())
    config = llm.config.compatible_options(request)

    try:
        result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
    except Exception as err:
        traceback.print_exc()
        logger.error('Error generating completion: %s', err)
        return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

    # best_of != n then we don't stream
    # TODO: support use_beam_search
    stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])

    def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
        # Serialize one streamed chunk for choice ``index``.
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
        )
        if usage is not None:
            response.usage = usage
        return jsonify_attr(response)

    async def completion_stream_generator():
        previous_num_tokens = [0] * config['n']
        previous_texts = [''] * config['n']
        previous_echo = [False] * config['n']
        async for res in result_generator:
            for output in res.outputs:
                i = output.index
                delta_text = output.text
                token_ids = output.token_ids
                logprobs = None
                top_logprobs = None
                if request.logprobs is not None:
                    top_logprobs = output.logprobs[previous_num_tokens[i] :]

                # The prompt is echoed once per choice, on that choice's first chunk.
                if request.echo and not previous_echo[i]:
                    if not echo_without_generation:
                        delta_text = res.prompt + delta_text
                        token_ids = res.prompt_token_ids + token_ids
                        if top_logprobs:
                            top_logprobs = res.prompt_logprobs + top_logprobs
                    else:
                        # max_tokens == 0 was requested: return the prompt only.
                        delta_text = res.prompt
                        token_ids = res.prompt_token_ids
                        if top_logprobs:
                            top_logprobs = res.prompt_logprobs
                    previous_echo[i] = True
                if request.logprobs is not None:
                    # Use the echo-adjusted values; previously they were computed and
                    # then discarded, so ``echo`` never affected streamed output.
                    logprobs = create_logprobs(
                        token_ids, top_logprobs, request.logprobs, len(previous_texts[i]), llm=llm
                    )
                # Bookkeeping tracks generated output only, never the echoed prompt.
                previous_num_tokens[i] += len(output.token_ids)
                previous_texts[i] += output.text
                chunk = create_stream_response_json(
                    index=i, text=delta_text, logprobs=logprobs, finish_reason=output.finish_reason
                )
                yield f'data: {chunk}\n\n'
                if output.finish_reason is not None:
                    # Trailing chunk: empty text plus the usage totals for this choice.
                    logprobs = LogProbs() if request.logprobs is not None else None
                    prompt_tokens = len(res.prompt_token_ids)
                    usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
                    final_chunk = create_stream_response_json(i, '', logprobs, output.finish_reason, usage)
                    yield f'data: {final_chunk}\n\n'
        yield 'data: [DONE]\n\n'

    try:
        # Streaming case
        if stream:
            return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
        # Non-streaming case
        final_result = None
        # One independent accumulator per choice. ``[[]] * n`` would alias a single
        # list n times and merge the output of every choice together.
        texts = [[] for _ in range(config['n'])]
        token_ids = [[] for _ in range(config['n'])]
        async for res in result_generator:
            if await req.is_disconnected():
                return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
            for output in res.outputs:
                texts[output.index].append(output.text)
                token_ids[output.index].extend(output.token_ids)
            final_result = res
        if final_result is None:
            return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
        # Replace the last chunk's outputs with the text/tokens accumulated over the run.
        final_result = final_result.with_options(
            outputs=[
                output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
                for output in final_result.outputs
            ]
        )

        choices = []
        prompt_token_ids = final_result.prompt_token_ids
        prompt_logprobs = final_result.prompt_logprobs
        prompt_text = final_result.prompt
        for output in final_result.outputs:
            logprobs = None
            if request.logprobs is not None:
                if not echo_without_generation:
                    choice_token_ids, top_logprobs = output.token_ids, output.logprobs
                    if request.echo:
                        choice_token_ids = prompt_token_ids + choice_token_ids
                        top_logprobs = prompt_logprobs + top_logprobs
                else:
                    choice_token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
                logprobs = create_logprobs(choice_token_ids, top_logprobs, request.logprobs, llm=llm)
            if not echo_without_generation:
                output_text = output.text
                if request.echo:
                    output_text = prompt_text + output_text
            else:
                output_text = prompt_text
            choices.append(
                CompletionResponseChoice(
                    index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason
                )
            )

        num_prompt_tokens = len(final_result.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
        usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
        response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
        return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
    except Exception as err:
        traceback.print_exc()
        logger.error('Error generating completion: %s', err)
        return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|