types: update stubs for remaining entrypoints (#667)

* perf(type): static OpenAI types definition

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: add hf types

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* types: update remaining missing stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-16 04:26:13 -05:00
committed by GitHub
parent 6102a67a83
commit 9e3f0fea15
13 changed files with 231 additions and 260 deletions

View File

@@ -1,33 +1,13 @@
"""Entrypoint for all third-party apps.
Currently supports the OpenAI-compatible API.
Each module should implement the following API:
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
"""
from __future__ import annotations
import importlib
import typing as t
from openllm_core.utils import LazyModule
if t.TYPE_CHECKING:
import bentoml
import openllm
_import_structure = {'openai': [], 'hf': [], 'cohere': []}
class IntegrationModule(t.Protocol):
def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ...
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []}
def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
def mount_entrypoints(svc, llm):
for module_name in _import_structure:
module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__))
module = importlib.import_module(f'.{module_name}', __name__)
svc = module.mount_to_svc(svc, llm)
return svc

View File

@@ -0,0 +1,16 @@
"""Entrypoint for all third-party apps.
Currently supports OpenAI- and Cohere-compatible APIs.
Each module should implement the following API:
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
"""
from bentoml import Service
from openllm_core._typing_compat import M, T
from . import cohere as cohere, hf as hf, openai as openai
from .._llm import LLM
# Mount every third-party API integration onto ``svc``.
# The implementation imports each integration submodule in turn and passes the
# service through that module's ``mount_to_svc(svc, llm)``; the service
# returned by the last call is the result.
def mount_entrypoints(svc: Service, llm: LLM[M, T]) -> Service: ...

View File

@@ -1,21 +1,14 @@
from __future__ import annotations
import functools
import inspect
import types
import typing as t
import attr
from starlette.routing import BaseRoute, Host, Mount, Route
from starlette.routing import Host, Mount, Route
from starlette.schemas import EndpointInfo, SchemaGenerator
from openllm_core._typing_compat import ParamSpec
from openllm_core.utils import first_not_none
if t.TYPE_CHECKING:
from attr import AttrsInstance
import bentoml
P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODELS_SCHEMA = """\
@@ -479,7 +472,7 @@ summary: Creates a model response for the given chat conversation.
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
return func
@@ -490,8 +483,8 @@ def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
class OpenLLMSchemaGenerator(SchemaGenerator):
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
endpoints_info: list[EndpointInfo] = []
def get_endpoints(self, routes):
endpoints_info = []
for route in routes:
if isinstance(route, (Mount, Host)):
routes = route.routes or []
@@ -523,7 +516,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]:
def get_schema(self, routes, mount_path=None):
schema = dict(self.base_schema)
schema.setdefault('paths', {})
endpoints_info = self.get_endpoints(routes)
@@ -543,13 +536,8 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
return schema
def get_generator(
title: str,
components: list[type[AttrsInstance]] | None = None,
tags: list[dict[str, t.Any]] | None = None,
inject: bool = True,
) -> OpenLLMSchemaGenerator:
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
def get_generator(title, components=None, tags=None, inject=True):
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
if components and inject:
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
if tags is not None and tags and inject:
@@ -557,12 +545,12 @@ def get_generator(
return OpenLLMSchemaGenerator(base_schema)
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
def component_schema_generator(attr_cls, description=None):
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
)
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
for field in attr.fields(attr.resolve_types(attr_cls)):
attr_type = field.type
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
@@ -593,7 +581,7 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
if 'prop_schema' not in locals():
prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
prop_schema['default'] = field.default # type: ignore[arg-type]
prop_schema['default'] = field.default
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
schema['required'].append(field.name)
schema['properties'][field.name] = prop_schema
@@ -602,20 +590,15 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
return schema
class MKSchema:
def __init__(self, it: dict[str, t.Any]) -> None:
self.it = it
def asdict(self) -> dict[str, t.Any]:
return self.it
_SimpleSchema = types.new_class(
'_SimpleSchema',
(object,),
{},
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
)
def append_schemas(
svc: bentoml.Service,
generated_schema: dict[str, t.Any],
tags_order: t.Literal['prepend', 'append'] = 'prepend',
inject: bool = True,
) -> bentoml.Service:
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
# HACK: Dirty hack to append schemas to existing service. We definitely need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
@@ -623,7 +606,7 @@ def append_schemas(
return svc
svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
if tags_order == 'prepend':
@@ -639,12 +622,9 @@ def append_schemas(
# HACK: mk this attribute until we have a better way to add starlette schemas.
from bentoml._internal.service import openapi
def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
return MKSchema(svc_schema)
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
return _SimpleSchema(svc_schema)
def mk_asdict(self):
return svc_schema
openapi.generate_spec = mk_generate_spec
OpenAPISpecification.asdict = mk_asdict
openapi.generate_spec = _generate_spec
OpenAPISpecification.asdict = lambda self: svc_schema
return svc

View File

@@ -0,0 +1,28 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Type
from attr import AttrsInstance
from starlette.routing import BaseRoute
from starlette.schemas import EndpointInfo
from bentoml import Service
from openllm_core._typing_compat import ParamSpec
P = ParamSpec('P')
class OpenLLMSchemaGenerator:
    """Typing surface of the schema generator returned by ``get_generator``."""

    # Base OpenAPI document; ``get_schema`` starts from a copy of it.
    base_schema: Dict[str, Any]

    def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
        # Walks ``routes`` (descending into Mount/Host routes) and returns the
        # collected ``EndpointInfo`` entries.
        ...

    def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]:
        # Returns a schema dict built from a copy of ``base_schema`` with a
        # 'paths' entry populated from the endpoints found in ``routes``.
        ...

    def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]:
        # NOTE(review): no implementation is visible in this module; presumably
        # inherited from starlette's SchemaGenerator -- confirm.
        ...
# Decorator for endpoint functions that have a predefined schema text.
# The schema is looked up by ``func.__name__.lower()``; when nothing is
# registered under that name, ``func`` is returned unchanged.
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
# Merge ``generated_schema`` into the OpenAPI document of ``svc`` and return ``svc``.
# ``tags_order`` is consulted only when ``generated_schema`` contains 'tags'
# (presumably it places them before or after the service's own tags -- confirm).
# The implementation replaces bentoml's ``openapi.generate_spec`` and
# ``OpenAPISpecification.asdict`` with functions that return the merged schema.
def append_schemas(
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
) -> Service: ...
# Build an OpenAPI object-schema dict for an attrs class.
# The result carries 'type', 'required', 'properties', 'title' and
# 'description'. The description is chosen from the class docstring, then
# ``description``, then a generated "Generated components for <name>" default.
# Each attrs field contributes one entry under 'properties'; a field default
# that is not a Factory is copied into that entry's 'default'.
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
# Create an ``OpenLLMSchemaGenerator`` whose base schema carries ``title`` plus
# the module's API and OpenAPI version constants.
# When ``inject`` is true, ``components`` are converted with
# ``component_schema_generator`` and stored under components.schemas keyed by
# class name. Non-empty ``tags`` are likewise only applied when ``inject`` is true.
def get_generator(
title: str,
components: Optional[List[Type[AttrsInstance]]] = ...,
tags: Optional[List[Dict[str, Any]]] = ...,
inject: bool = ...,
) -> OpenLLMSchemaGenerator: ...

View File

@@ -3,7 +3,6 @@ import functools
import json
import logging
import traceback
import typing as t
from http import HTTPStatus
import orjson
@@ -54,26 +53,16 @@ schemas = get_generator(
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import Response
import bentoml
import openllm
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import M, T
def jsonify_attr(obj: AttrsInstance) -> str:
def jsonify_attr(obj):
return json.dumps(converter.unstructure(obj))
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
def error_response(status_code, message):
return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)
async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None:
async def check_model(request, model):
if request.model is None or request.model == model:
return None
return error_response(
@@ -82,7 +71,7 @@ async def check_model(request: CohereGenerateRequest | CohereChatRequest, model:
)
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
def mount_to_svc(svc, llm):
app = Starlette(
debug=True,
routes=[
@@ -90,7 +79,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
'/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']
),
Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
mount_path = '/cohere'
@@ -102,7 +91,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
@add_schema_definitions
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
async def cohere_generate(req, llm):
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
@@ -115,7 +104,6 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
err_check = await check_model(request, llm.llm_type)
if err_check is not None:
return err_check
request_id = gen_random_uuid('cohere-generate')
config = llm.config.with_request(request)
@@ -133,10 +121,10 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def create_stream_response_json(index: int, text: str, is_finished: bool) -> str:
def create_stream_response_json(index, text, is_finished):
return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'
async def generate_stream_generator() -> t.AsyncGenerator[str, None]:
async def generate_stream_generator():
async for res in result_generator:
for output in res.outputs:
yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason)
@@ -146,9 +134,8 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
if request.stream:
return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[list[str]] = [[]] * config['n']
token_ids: list[list[int]] = [[]] * config['n']
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
@@ -164,14 +151,18 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
for output in final_result.outputs
]
)
response = Generations(
id=request_id,
generations=[
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
for output in final_result.outputs
],
return JSONResponse(
converter.unstructure(
Generations(
id=request_id,
generations=[
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
for output in final_result.outputs
],
)
),
status_code=HTTPStatus.OK.value,
)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
@@ -192,7 +183,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str
@add_schema_definitions
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
async def cohere_chat(req, llm):
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), CohereChatRequest)
@@ -223,9 +214,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
texts: list[str] = []
token_ids: list[int] = []
async def completion_stream_generator():
texts, token_ids = [], []
yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'
it = None
@@ -237,9 +227,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
if it is None:
raise ValueError('No response from model.')
num_prompt_tokens = len(t.cast(t.List[int], it.prompt_token_ids))
num_response_tokens = len(token_ids)
total_tokens = num_prompt_tokens + num_response_tokens
num_prompt_tokens, num_response_tokens = len(it.prompt_token_ids), len(token_ids)
json_str = jsonify_attr(
ChatStreamEnd(
@@ -255,7 +243,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
token_count={
'prompt_tokens': num_prompt_tokens,
'response_tokens': num_response_tokens,
'total_tokens': total_tokens,
'total_tokens': num_prompt_tokens + num_response_tokens,
},
),
)
@@ -266,9 +254,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
if request.stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[str] = []
token_ids: list[int] = []
final_result = None
texts, token_ids = [], []
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
@@ -280,28 +267,25 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
final_result = final_result.with_options(
outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
)
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids))
num_response_tokens = len(token_ids)
total_tokens = num_prompt_tokens + num_response_tokens
response = Chat(
response_id=request_id,
message=request.message,
text=''.join(texts),
prompt=prompt,
chat_history=request.chat_history,
token_count={
'prompt_tokens': num_prompt_tokens,
'response_tokens': num_response_tokens,
'total_tokens': total_tokens,
},
num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids)
return JSONResponse(
converter.unstructure(
Chat(
response_id=request_id,
message=request.message,
text=''.join(texts),
prompt=prompt,
chat_history=request.chat_history,
token_count={
'prompt_tokens': num_prompt_tokens,
'response_tokens': num_response_tokens,
'total_tokens': num_prompt_tokens + num_response_tokens,
},
)
),
status_code=HTTPStatus.OK.value,
)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def openapi_schema(req: Request) -> Response:
return schemas.OpenAPIResponse(req)

View File

@@ -0,0 +1,21 @@
from http import HTTPStatus
from typing import Optional, Union
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from bentoml import Service
from openllm_core._typing_compat import M, T
from .._llm import LLM
from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest
# Attach the Cohere-compatible Starlette app to ``svc``.
# The app registers POST /v1/generate, POST /v1/chat and a /schema route that
# is excluded from the generated schema; the implementation uses '/cohere' as
# its mount path.
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
# Serialize an attrs instance to a JSON string: the object is unstructured by
# the module's converter and the result is dumped with ``json.dumps``.
def jsonify_attr(obj: AttrsInstance) -> str: ...
# Build a JSONResponse whose body is an unstructured ``CohereErrorResponse``
# carrying ``message`` and whose HTTP status is ``status_code.value``.
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
# Validate the model named in ``request`` against ``model``.
# Returns None when the request names no model or names exactly ``model``;
# otherwise returns an error response for the caller to send back.
async def check_model(
request: Union[CohereGenerateRequest, CohereChatRequest], model: str
) -> Optional[JSONResponse]: ...
# Handler for POST /v1/generate.
# The request body is parsed into a ``CohereGenerateRequest``. A streaming
# request is answered with a 'text/event-stream' StreamingResponse; otherwise
# the reply is a JSONResponse holding a ``Generations`` payload. Failures are
# reported through ``error_response``.
async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ...
# Handler for POST /v1/chat.
# The request body is parsed into a ``CohereChatRequest``. A streaming request
# is answered with a 'text/event-stream' StreamingResponse; otherwise the reply
# is a JSONResponse holding a ``Chat`` payload that includes prompt, response
# and total token counts. Failures are reported through ``error_response``.
async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ...

View File

@@ -1,8 +1,5 @@
from __future__ import annotations
import functools
import logging
import typing as t
from enum import Enum
from http import HTTPStatus
import orjson
@@ -28,23 +25,14 @@ schemas = get_generator(
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from peft.config import PeftConfig
from starlette.requests import Request
from starlette.responses import Response
import bentoml
import openllm
from openllm_core._typing_compat import M, T
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
def mount_to_svc(svc, llm):
app = Starlette(
debug=True,
routes=[
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
mount_path = '/hf'
@@ -52,7 +40,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
def error_response(status_code, message):
return JSONResponse(
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
status_code=status_code.value,
@@ -60,7 +48,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
@add_schema_definitions
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
async def hf_agent(req, llm):
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), AgentRequest)
@@ -81,17 +69,13 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
@add_schema_definitions
def hf_adapters(req: Request, llm: openllm.LLM[M, T]) -> Response:
def hf_adapters(req, llm):
if not llm.has_adapters:
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
return JSONResponse(
{
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value}
for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value}
for k, adapter_tuple in dict(*llm.adapter_map.values()).items()
},
status_code=HTTPStatus.OK.value,
)
def openapi_schema(req: Request) -> Response:
return schemas.OpenAPIResponse(req)

View File

@@ -0,0 +1,14 @@
from http import HTTPStatus
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from bentoml import Service
from openllm_core._typing_compat import M, T
from .._llm import LLM
# Attach the Hugging Face Starlette app to ``svc``.
# The app registers POST /agent, GET /adapters and a /schema route that is
# excluded from the generated schema, under the '/hf' mount path. The return
# value is the result of ``append_schemas`` called with ``tags_order='append'``.
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
# Build a JSONResponse whose body is an unstructured ``HFErrorResponse`` with
# ``message`` and ``error_code=status_code.value``; the HTTP status is the
# same ``status_code.value``.
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
# Handler for POST /agent: reads the request body and structures it into an
# ``AgentRequest`` before generating the response.
async def hf_agent(req: Request, llm: LLM[M, T]) -> Response: ...
# Handler for GET /adapters.
# Responds 404 with 'No adapters found.' when ``llm.has_adapters`` is falsy;
# otherwise responds 200 with a JSON mapping whose values hold each adapter's
# 'adapter_name' and 'adapter_type' (the peft type value).
def hf_adapters(req: Request, llm: LLM[M, T]) -> Response: ...

View File

@@ -1,4 +1,3 @@
from __future__ import annotations
import functools
import logging
import time
@@ -57,22 +56,12 @@ schemas = get_generator(
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import Response
import bentoml
import openllm
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import M, T
def jsonify_attr(obj: AttrsInstance) -> str:
def jsonify_attr(obj):
return orjson.dumps(converter.unstructure(obj)).decode()
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
def error_response(status_code, message):
return JSONResponse(
{
'error': converter.unstructure(
@@ -83,7 +72,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
)
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
async def check_model(request, model):
if request.model == model:
return None
return error_response(
@@ -92,9 +81,7 @@ async def check_model(request: CompletionRequest | ChatCompletionRequest, model:
)
def create_logprobs(
token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]
) -> LogProbs:
def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
# Create OpenAI-style logprobs.
logprobs = LogProbs()
last_token_len = 0
@@ -112,24 +99,23 @@ def create_logprobs(
return logprobs
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
def mount_to_svc(svc, llm):
app = Starlette(
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
mount_path = '/v1'
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path))
svc.mount_asgi_app(app, path='/v1')
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path='/v1'))
# GET /v1/models
@add_schema_definitions
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
def list_models(_, llm):
return JSONResponse(
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
)
@@ -137,7 +123,7 @@ def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
# POST /v1/chat/completions
@add_schema_definitions
async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
async def chat_completions(req, llm):
# TODO: Check for length based on model context_length
json_str = await req.body()
try:
@@ -166,7 +152,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
def create_stream_response_json(index, text, finish_reason=None):
return jsonify_attr(
ChatCompletionStreamResponse(
id=request_id,
@@ -178,7 +164,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
)
)
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
async def completion_stream_generator():
# first chunk with role
for i in range(config['n']):
yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
@@ -195,9 +181,8 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
if request.stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[list[str]] = [[]] * config['n']
token_ids: list[list[int]] = [[]] * config['n']
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
@@ -221,15 +206,9 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
)
for output in final_result.outputs
]
num_prompt_tokens, num_generated_tokens = (
len(t.cast(t.List[int], final_result.prompt_token_ids)),
sum(len(output.token_ids) for output in final_result.outputs),
)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = ChatCompletionResponse(
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
)
@@ -254,7 +233,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
# POST /v1/completions
@add_schema_definitions
async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
async def completions(req, llm):
# TODO: Check for length based on model context_length
json_str = await req.body()
try:
@@ -300,9 +279,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
# TODO: support use_beam_search
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
def create_stream_response_json(
index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None
) -> str:
def create_stream_response_json(index, text, logprobs=None, finish_reason=None):
return jsonify_attr(
CompletionStreamResponse(
id=request_id,
@@ -314,16 +291,14 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
)
)
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
async def completion_stream_generator():
previous_num_tokens = [0] * config['n']
async for res in result_generator:
for output in res.outputs:
i = output.index
if request.logprobs is not None:
logprobs = create_logprobs(
token_ids=output.token_ids,
id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :],
llm=llm,
token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm
)
else:
logprobs = None
@@ -339,9 +314,8 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
if stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result: GenerationOutput | None = None
texts: list[list[str]] = [[]] * config['n']
token_ids: list[list[int]] = [[]] * config['n']
final_result = None
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
@@ -358,7 +332,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
]
)
choices: list[CompletionResponseChoice] = []
choices = []
for output in final_result.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(
@@ -371,15 +345,9 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
)
choices.append(choice_data)
num_prompt_tokens = len(
t.cast(t.List[int], final_result.prompt_token_ids)
) # XXX: We will always return prompt_token_ids, so this won't be None
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
if request.stream:
@@ -398,7 +366,3 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def openapi_schema(req: Request) -> Response:
return schemas.OpenAPIResponse(req)

View File

@@ -0,0 +1,25 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from bentoml import Service
from openllm_core._typing_compat import M, T
from .._llm import LLM
from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
# Attach the OpenAI-compatible Starlette app to ``svc`` at '/v1'.
# The app registers GET /models, POST /completions, POST /chat/completions and
# a /schema route that is excluded from the generated schema. The return value
# is the result of ``append_schemas`` for the app's routes.
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
# Serialize an attrs instance to a JSON string: the object is unstructured by
# the module's converter, dumped with ``orjson.dumps`` and decoded to ``str``.
def jsonify_attr(obj: AttrsInstance) -> str: ...
# Build a JSONResponse whose body nests the unstructured error object under
# the 'error' key.
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
# Validate the model named in ``request`` against ``model``.
# Returns None only when ``request.model == model``; any other value yields an
# error response for the caller to send back.
async def check_model(
request: Union[CompletionRequest, ChatCompletionRequest], model: str
) -> Optional[JSONResponse]: ...
# Create OpenAI-style logprobs: a fresh ``LogProbs`` object is filled from
# ``token_ids`` / ``id_logprobs`` and returned. ``llm`` is keyword-only.
def create_logprobs(
token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T]
) -> LogProbs: ...
# Handler for GET /v1/models: responds 200 with a ``ModelList`` containing a
# single ``ModelCard`` whose id is ``llm.llm_type``. The request is not used.
# NOTE(review): the implementation names its first parameter ``_`` rather than
# ``req``, so it should only ever be passed positionally.
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
# Handler for POST /v1/chat/completions.
# A streaming request is answered with a 'text/event-stream' StreamingResponse;
# otherwise the reply is built from a ``ChatCompletionResponse`` that includes
# ``UsageInfo`` token counts. Failures are reported through ``error_response``.
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
# Handler for POST /v1/completions.
# Streaming is used only when the request asks for it and ``best_of`` is unset
# or equal to ``n``; otherwise the reply is built from a ``CompletionResponse``
# that includes ``UsageInfo`` token counts. Failures are reported through
# ``error_response``.
async def completions(req: Request, llm: LLM[M, T]) -> Response: ...

View File

@@ -8,7 +8,7 @@ import click
import click_option_group as cog
import inflection
from bentoml_cli.utils import BentoMLCommandGroup
from click import ClickException, shell_completion as sc
from click import shell_completion as sc
import bentoml
import openllm
@@ -22,7 +22,7 @@ from openllm_core._typing_compat import (
ParamSpec,
get_literal_args,
)
from openllm_core.utils import DEBUG
from openllm_core.utils import DEBUG, resolve_user_filepath
class _OpenLLM_GenericInternalConfig(LLMConfig):
@@ -125,12 +125,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
# try to resolve the full path if users pass in relative,
# currently only support one level of resolve path with current directory
try:
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
if len(adapter_name) == 0:
raise ClickException(f'Adapter name is required for {adapter_id}')
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
ctx.params[_adapter_mapping_key][adapter_id] = name
return None
@@ -171,32 +170,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
help='Assign GPU devices (if available)',
show_envvar=True,
),
cog.optgroup.group(
'Fine-tuning related options',
help="""\
Note that the argument `--adapter-id` can accept the following format:
- `--adapter-id /path/to/adapter` (local adapter)
- `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub)
- `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name)
```bash
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
```
""",
),
cog.optgroup.option(
'--adapter-id',
default=None,
help='Optional name or path for given LoRA adapter',
multiple=True,
callback=_id_callback,
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
),
adapter_id_option(factory=cog.optgroup),
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
)
return composed(fn)
@@ -285,6 +259,17 @@ cli_option = functools.partial(_click_factory_type, attr='option')
cli_argument = functools.partial(_click_factory_type, attr='argument')
def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
    """Build the repeatable ``--adapter-id`` CLI option.

    Args:
        f: Optional callable to decorate immediately. When omitted, the option
            decorator is returned so it can be composed with other decorators.
        **attrs: Extra keyword arguments forwarded to ``cli_option``, for
            example ``factory=cog.optgroup`` to register the option through an
            option group.

    Returns:
        The option decorator when ``f`` is None, otherwise ``f`` with the
        option applied.
    """
    # Forward ``attrs``: they used to be dropped, so a call such as
    # ``adapter_id_option(factory=cog.optgroup)`` silently lost ``factory``.
    # NOTE(review): assumes ``cli_option`` consumes ``factory`` itself rather
    # than passing it on to click -- confirm against ``_click_factory_type``.
    option = cli_option(
        '--adapter-id',
        default=None,
        help='Optional name or path for given LoRA adapter',
        multiple=True,
        callback=_id_callback,
        metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
        **attrs,
    )
    # ``f`` was previously ignored as well; apply the option when it is given.
    return option if f is None else option(f)
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--cors/--no-cors',

View File

@@ -691,7 +691,7 @@ def pretty_print(line: str):
elif 'ERROR' in line:
caller = termui.error
else:
caller = caller = functools.partial(termui.echo, fg=None)
caller = functools.partial(termui.echo, fg=None)
caller(line.strip())
@@ -1026,6 +1026,7 @@ def build_command(
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_SERIALIZATION': llm._serialisation,
'OPENLLM_MODEL_ID': llm.model_id,
'OPENLLM_ADAPTER_MAP': orjson.dumps(None).decode(),
}
)
if llm.quantise:
@@ -1035,13 +1036,8 @@ def build_command(
if prompt_template:
os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
try:
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
# FIX: This is a patch for _service_vars injection
if 'OPENLLM_ADAPTER_MAP' not in os.environ:
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__})
@@ -1190,14 +1186,8 @@ class ModelItem(t.TypedDict):
@cli.command()
@click.option(
'--show-available',
is_flag=True,
default=True,
hidden=True,
help="Show available models in local store (mutually exclusive with '-o porcelain').",
)
def models_command(show_available: bool) -> dict[t.LiteralString, ModelItem]:
@click.option('--show-available', is_flag=True, default=True, hidden=True)
def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
"""List all supported models.
\b