mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-19 15:18:12 -05:00
chore(openapi): unify inject param (#645)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -611,11 +611,17 @@ class MKSchema:
 def append_schemas(
-  svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
+  svc: bentoml.Service,
+  generated_schema: dict[str, t.Any],
+  tags_order: t.Literal['prepend', 'append'] = 'prepend',
+  inject: bool = True,
 ) -> bentoml.Service:
   # HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
   from bentoml._internal.service.openapi.specification import OpenAPISpecification
 
+  if not inject:
+    return svc
+
   svc_schema = svc.openapi_spec
   if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
     svc_schema = svc_schema.asdict()
@@ -13,9 +13,7 @@ from starlette.routing import Route
 
 from openllm_core.utils import converter, gen_random_uuid
 
-from ._openapi import append_schemas, get_generator
-
-# from ._openapi import add_schema_definitions
+from ._openapi import add_schema_definitions, append_schemas, get_generator
 from ..protocol.cohere import (
   Chat,
   ChatStreamEnd,
@@ -89,29 +87,21 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
     debug=True,
     routes=[
       Route(
-        '/v1/generate',
-        endpoint=functools.partial(cohere_generate, llm=llm),
-        name='cohere_generate',
-        methods=['POST'],
-        include_in_schema=False,
-      ),
-      Route(
-        '/v1/chat',
-        endpoint=functools.partial(cohere_chat, llm=llm),
-        name='cohere_chat',
-        methods=['POST'],
-        include_in_schema=False,
+        '/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']
       ),
+      Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
       Route('/schema', endpoint=openapi_schema, include_in_schema=False),
     ],
   )
   mount_path = '/cohere'
 
   svc.mount_asgi_app(app, path=mount_path)
-  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
+  return append_schemas(
+    svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=False
+  )
 
 
-# @add_schema_definitions
+@add_schema_definitions
 async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
   json_str = await req.body()
   try:
@@ -201,7 +191,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str
   return messages
 
 
-# @add_schema_definitions
+@add_schema_definitions
 async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
   json_str = await req.body()
   try:
Reference in New Issue
Block a user