chore(openapi): unify inject param (#645)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-14 01:16:20 -05:00
committed by GitHub
parent b0ab8ccdf6
commit 00d6016bcb
2 changed files with 15 additions and 19 deletions

View File

@@ -611,11 +611,17 @@ class MKSchema:
def append_schemas(
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
svc: bentoml.Service,
generated_schema: dict[str, t.Any],
tags_order: t.Literal['prepend', 'append'] = 'prepend',
inject: bool = True,
) -> bentoml.Service:
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
if not inject:
return svc
svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
svc_schema = svc_schema.asdict()

View File

@@ -13,9 +13,7 @@ from starlette.routing import Route
from openllm_core.utils import converter, gen_random_uuid
from ._openapi import append_schemas, get_generator
# from ._openapi import add_schema_definitions
from ._openapi import add_schema_definitions, append_schemas, get_generator
from ..protocol.cohere import (
Chat,
ChatStreamEnd,
@@ -89,29 +87,21 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
debug=True,
routes=[
Route(
'/v1/generate',
endpoint=functools.partial(cohere_generate, llm=llm),
name='cohere_generate',
methods=['POST'],
include_in_schema=False,
),
Route(
'/v1/chat',
endpoint=functools.partial(cohere_chat, llm=llm),
name='cohere_chat',
methods=['POST'],
include_in_schema=False,
'/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']
),
Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
],
)
mount_path = '/cohere'
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
return append_schemas(
svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=False
)
# @add_schema_definitions
@add_schema_definitions
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
json_str = await req.body()
try:
@@ -201,7 +191,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str
return messages
# @add_schema_definitions
@add_schema_definitions
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
json_str = await req.body()
try: