infra: using ruff formatter (#594)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-09 12:44:05 -05:00
committed by GitHub
parent 021fd453b9
commit ac377fe490
102 changed files with 5577 additions and 2540 deletions

View File

@@ -6,6 +6,7 @@ Each module should implement the following API:
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
"""
from __future__ import annotations
import typing as t
@@ -14,16 +15,21 @@ from openllm_core.utils import LazyModule
from . import hf as hf
from . import openai as openai
if t.TYPE_CHECKING:
import bentoml
import openllm
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}
def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
__lazy = LazyModule(
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__

View File

@@ -15,6 +15,7 @@ from starlette.schemas import SchemaGenerator
from openllm_core._typing_compat import ParamSpec
from openllm_core.utils import first_not_none
if t.TYPE_CHECKING:
from attr import AttrsInstance
@@ -23,7 +24,7 @@ if t.TYPE_CHECKING:
P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODEL_SCHEMA = '''\
LIST_MODEL_SCHEMA = """\
---
consumes:
- application/json
@@ -53,8 +54,8 @@ responses:
owned_by: 'na'
schema:
$ref: '#/components/schemas/ModelList'
'''
CHAT_COMPLETION_SCHEMA = '''\
"""
CHAT_COMPLETION_SCHEMA = """\
---
consumes:
- application/json
@@ -191,8 +192,8 @@ responses:
}
}
description: Bad Request
'''
COMPLETION_SCHEMA = '''\
"""
COMPLETION_SCHEMA = """\
---
consumes:
- application/json
@@ -344,8 +345,8 @@ responses:
}
}
description: Bad Request
'''
HF_AGENT_SCHEMA = '''\
"""
HF_AGENT_SCHEMA = """\
---
consumes:
- application/json
@@ -389,8 +390,8 @@ responses:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Not Found
'''
HF_ADAPTERS_SCHEMA = '''\
"""
HF_ADAPTERS_SCHEMA = """\
---
consumes:
- application/json
@@ -420,16 +421,19 @@ responses:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Not Found
'''
"""
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
if func.__doc__ is None: func.__doc__ = ''
if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
return func
return docstring_decorator
class OpenLLMSchemaGenerator(SchemaGenerator):
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
endpoints_info: list[EndpointInfo] = []
@@ -437,20 +441,29 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
if isinstance(route, (Mount, Host)):
routes = route.routes or []
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
sub_endpoints = [EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func) for sub_endpoint in self.get_endpoints(routes)]
sub_endpoints = [
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
elif (
inspect.isfunction(route.endpoint)
or inspect.ismethod(route.endpoint)
or isinstance(route.endpoint, functools.partial)
):
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
path = self._remove_converter(route.path)
for method in route.methods or ['GET']:
if method == 'HEAD': continue
if method == 'HEAD':
continue
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
else:
path = self._remove_converter(route.path)
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
if not hasattr(route.endpoint, method): continue
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
@@ -459,37 +472,52 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
schema = dict(self.base_schema)
schema.setdefault('paths', {})
endpoints_info = self.get_endpoints(routes)
if mount_path: mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
if mount_path:
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed: continue
if not parsed:
continue
path = endpoint.path if mount_path is None else mount_path + endpoint.path
if path not in schema['paths']: schema['paths'][path] = {}
if path not in schema['paths']:
schema['paths'][path] = {}
schema['paths'][path][endpoint.http_method] = parsed
return schema
def get_generator(title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None) -> OpenLLMSchemaGenerator:
def get_generator(
title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
) -> OpenLLMSchemaGenerator:
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
if components: base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
if tags is not None and tags: base_schema['tags'] = tags
if components:
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
if tags is not None and tags:
base_schema['tags'] = tags
return OpenLLMSchemaGenerator(base_schema)
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
schema['description'] = first_not_none(
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
)
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
attr_type = field.type
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
# Map Python types to OpenAPI schema types
if attr_type == str: schema_type = 'string'
elif attr_type == int: schema_type = 'integer'
elif attr_type == float: schema_type = 'number'
elif attr_type == bool: schema_type = 'boolean'
if attr_type == str:
schema_type = 'string'
elif attr_type == int:
schema_type = 'integer'
elif attr_type == float:
schema_type = 'number'
elif attr_type == bool:
schema_type = 'boolean'
elif origin_type is list or origin_type is tuple:
schema_type = 'array'
elif origin_type is dict:
@@ -504,14 +532,18 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
else:
schema_type = 'string'
if 'prop_schema' not in locals(): prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): prop_schema['default'] = field.default # type: ignore[arg-type]
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name)
if 'prop_schema' not in locals():
prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
prop_schema['default'] = field.default # type: ignore[arg-type]
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
schema['required'].append(field.name)
schema['properties'][field.name] = prop_schema
locals().pop('prop_schema', None)
return schema
class MKSchema:
def __init__(self, it: dict[str, t.Any]) -> None:
self.it = it
@@ -519,19 +551,30 @@ class MKSchema:
def asdict(self) -> dict[str, t.Any]:
return self.it
def append_schemas(svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend') -> bentoml.Service:
def append_schemas(
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
) -> bentoml.Service:
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
svc_schema: t.Any = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict()
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
if tags_order == 'prepend': svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
elif tags_order == 'append': svc_schema['tags'].extend(generated_schema['tags'])
else: raise ValueError(f'Invalid tags_order: {tags_order}')
if 'components' in generated_schema: svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
if tags_order == 'prepend':
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
elif tags_order == 'append':
svc_schema['tags'].extend(generated_schema['tags'])
else:
raise ValueError(f'Invalid tags_order: {tags_order}')
if 'components' in generated_schema:
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
svc_schema['paths'].update(generated_schema['paths'])
from bentoml._internal.service import openapi # HACK: mk this attribute until we have a better way to add starlette schemas.
from bentoml._internal.service import (
openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
)
# yapf: disable
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)

View File

@@ -23,17 +23,21 @@ from ..protocol.hf import AgentRequest
from ..protocol.hf import AgentResponse
from ..protocol.hf import HFErrorResponse
schemas = get_generator('hf',
components=[AgentRequest, AgentResponse, HFErrorResponse],
tags=[{
'name': 'HF',
'description': 'HF integration, including Agent and others schema endpoints.',
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent'
}])
schemas = get_generator(
'hf',
components=[AgentRequest, AgentResponse, HFErrorResponse],
tags=[
{
'name': 'HF',
'description': 'HF integration, including Agent and others schema endpoints.',
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent',
}
],
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from peft.config import PeftConfig
from starlette.requests import Request
from starlette.responses import Response
@@ -44,20 +48,28 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
app = Starlette(debug=True,
routes=[
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False)
])
app = Starlette(
debug=True,
routes=[
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
],
)
mount_path = '/hf'
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, generated_schema, tags_order='append')
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
return JSONResponse(
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
status_code=status_code.value,
)
@add_schema_definitions(HF_AGENT_SCHEMA)
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
@@ -72,22 +84,26 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
stop = request.parameters.pop('stop', ['\n'])
try:
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
return JSONResponse(converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value)
return JSONResponse(
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
)
except Exception as err:
logger.error('Error while generating: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
@add_schema_definitions(HF_ADAPTERS_SCHEMA)
def adapters_map(req: Request, llm: openllm.LLM[M, T]) -> Response:
if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
if not llm.has_adapters:
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
return JSONResponse(
{
adapter_tuple[1]: {
'adapter_name': k,
'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value
} for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
},
status_code=HTTPStatus.OK.value)
{
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value}
for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
},
status_code=HTTPStatus.OK.value,
)
def openapi_schema(req: Request) -> Response:
return schemas.OpenAPIResponse(req)

View File

@@ -42,14 +42,27 @@ from ..protocol.openai import ModelCard
from ..protocol.openai import ModelList
from ..protocol.openai import UsageInfo
schemas = get_generator(
'openai',
components=[ErrorResponse, ModelList, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionStreamResponse, CompletionRequest, CompletionResponse, CompletionStreamResponse],
tags=[{
'name': 'OpenAI',
'description': 'OpenAI Compatible API support',
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object'
}])
'openai',
components=[
ErrorResponse,
ModelList,
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionResponse,
CompletionStreamResponse,
],
tags=[
{
'name': 'OpenAI',
'description': 'OpenAI Compatible API support',
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
}
],
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
@@ -64,20 +77,34 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
def jsonify_attr(obj: AttrsInstance) -> str:
return orjson.dumps(converter.unstructure(obj)).decode()
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
return JSONResponse({'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))}, status_code=status_code.value)
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
if request.model == model: return None
return error_response(
HTTPStatus.NOT_FOUND,
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request."
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
return JSONResponse(
{
'error': converter.unstructure(
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
)
},
status_code=status_code.value,
)
def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]) -> LogProbs:
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
if request.model == model:
return None
return error_response(
HTTPStatus.NOT_FOUND,
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
)
def create_logprobs(
token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]
) -> LogProbs:
# Create OpenAI-style logprobs.
logprobs = LogProbs()
last_token_len = 0
@@ -94,22 +121,29 @@ def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], i
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
return logprobs
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
app = Starlette(debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST'])
])
app = Starlette(
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST']),
],
)
mount_path = '/v1'
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, generated_schema)
# GET /v1/models
@add_schema_definitions(LIST_MODEL_SCHEMA)
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
return JSONResponse(
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
)
# POST /v1/chat/completions
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
@@ -124,11 +158,14 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received chat completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None: return err_check
if err_check is not None:
return err_check
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
created_time = int(time.monotonic())
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt'])
prompt = llm.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
config = llm.config.with_openai_request(request)
@@ -141,10 +178,15 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
return jsonify_attr(
ChatCompletionStreamResponse(id=request_id,
created=created_time,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)]))
ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
],
)
)
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
# first chunk with role
@@ -160,25 +202,47 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
try:
# Streaming case
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
if request.stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[list[str]] = [[]] * config['n']
token_ids: list[list[int]] = [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(
outputs=[
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
for output in final_result.outputs
]
)
choices = [
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role='assistant', content=output.text), finish_reason=output.finish_reason) for output in final_result.outputs
ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role='assistant', content=output.text),
finish_reason=output.finish_reason,
)
for output in final_result.outputs
]
num_prompt_tokens, num_generated_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)), sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
num_prompt_tokens, num_generated_tokens = (
len(t.cast(t.List[int], final_result.prompt_token_ids)),
sum(len(output.token_ids) for output in final_result.outputs),
)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
)
if request.stream: # type: ignore[unreachable]
# When user requests streaming but we don't stream, we still need to
@@ -187,7 +251,9 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
yield f'data: {jsonify_attr(response)}\n\n'
yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return StreamingResponse(
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
@@ -195,6 +261,7 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
# POST /v1/completions
@add_schema_definitions(COMPLETION_SCHEMA)
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
@@ -208,18 +275,25 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received legacy completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None: return err_check
if err_check is not None:
return err_check
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
if request.echo:
return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
if request.suffix is not None:
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
if not request.prompt:
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
prompt = request.prompt
# TODO: Support multiple prompts
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
return error_response(HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`.")
return error_response(
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
)
model_name, request_id = request.model, gen_random_uuid('cmpl')
created_time = int(time.monotonic())
@@ -236,12 +310,19 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
# TODO: support use_beam_search
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
def create_stream_response_json(index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None) -> str:
def create_stream_response_json(
index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None
) -> str:
return jsonify_attr(
CompletionStreamResponse(id=request_id,
created=created_time,
model=model_name,
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)]))
CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
],
)
)
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
previous_num_tokens = [0] * config['n']
@@ -249,7 +330,11 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
for output in res.outputs:
i = output.index
if request.logprobs is not None:
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i]:], llm=llm)
logprobs = create_logprobs(
token_ids=output.token_ids,
id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :],
llm=llm,
)
else:
logprobs = None
previous_num_tokens[i] += len(output.token_ids)
@@ -261,32 +346,50 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
try:
# Streaming case
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
if stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[list[str]] = [[]] * config['n']
token_ids: list[list[int]] = [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.with_options(
outputs=[
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
for output in final_result.outputs
]
)
choices: list[CompletionResponseChoice] = []
for output in final_result.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm)
logprobs = create_logprobs(
token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
)
else:
logprobs = None
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
choice_data = CompletionResponseChoice(
index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
)
choices.append(choice_data)
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)) # XXX: We will always return prompt_token_ids, so this won't be None
num_prompt_tokens = len(
t.cast(t.List[int], final_result.prompt_token_ids)
) # XXX: We will always return prompt_token_ids, so this won't be None
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
if request.stream:
@@ -296,7 +399,9 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
yield f'data: {jsonify_attr(response)}\n\n'
yield 'data: [DONE]\n\n'
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
return StreamingResponse(
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err: