From 9e3f0fea152928f22e9baccbc794bdc3c6ddc542 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 16 Nov 2023 04:26:13 -0500 Subject: [PATCH] types: update stubs for remaining entrypoints (#667) * perf(type): static OpenAI types definition Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * feat: add hf types Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * types: update remaining missing stubs Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- mypy.ini | 2 +- .../src/openllm/entrypoints/__init__.py | 26 +---- .../src/openllm/entrypoints/__init__.pyi | 16 +++ .../src/openllm/entrypoints/_openapi.py | 68 ++++------- .../src/openllm/entrypoints/_openapi.pyi | 28 +++++ .../src/openllm/entrypoints/cohere.py | 106 ++++++++---------- .../src/openllm/entrypoints/cohere.pyi | 21 ++++ openllm-python/src/openllm/entrypoints/hf.py | 30 ++--- openllm-python/src/openllm/entrypoints/hf.pyi | 14 +++ .../src/openllm/entrypoints/openai.py | 88 +++++---------- .../src/openllm/entrypoints/openai.pyi | 25 +++++ openllm-python/src/openllm_cli/_factory.py | 49 +++----- openllm-python/src/openllm_cli/entrypoint.py | 18 +-- 13 files changed, 231 insertions(+), 260 deletions(-) create mode 100644 openllm-python/src/openllm/entrypoints/__init__.pyi create mode 100644 openllm-python/src/openllm/entrypoints/_openapi.pyi create mode 100644 openllm-python/src/openllm/entrypoints/cohere.pyi create mode 100644 openllm-python/src/openllm/entrypoints/hf.pyi create mode 100644 openllm-python/src/openllm/entrypoints/openai.pyi diff --git a/mypy.ini b/mypy.ini index 7962f935..4d5342fd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,4 +7,4 @@ warn_unused_configs = True ignore_missing_imports = true check_untyped_defs = true warn_unreachable = true -files = openllm-python/src/openllm/bundle/__init__.pyi, 
openllm-python/src/openllm/serialisation/__init__.pyi, openllm-client/src/openllm_client/__init__.pyi, openllm-client/src/openllm_client/_utils.pyi, openllm-python/src/openllm/__init__.pyi, openllm-client/src/openllm_client/_typing_compat.py, openllm-core/src/openllm_core/_typing_compat.py, openllm-python/src/openllm/client.pyi, openllm-python/src/openllm/bundle/_package.pyi, openllm-python/src/openllm/_runners.pyi, openllm-python/src/openllm/_quantisation.pyi, openllm-python/src/openllm/_llm.pyi, openllm-python/src/openllm/_generation.pyi +files = openllm-python/src/openllm/bundle/__init__.pyi, openllm-python/src/openllm/serialisation/__init__.pyi, openllm-client/src/openllm_client/__init__.pyi, openllm-client/src/openllm_client/_utils.pyi, openllm-python/src/openllm/__init__.pyi, openllm-client/src/openllm_client/_typing_compat.py, openllm-core/src/openllm_core/_typing_compat.py, openllm-python/src/openllm/client.pyi, openllm-python/src/openllm/bundle/_package.pyi, openllm-python/src/openllm/_runners.pyi, openllm-python/src/openllm/_quantisation.pyi, openllm-python/src/openllm/_llm.pyi, openllm-python/src/openllm/_generation.pyi, openllm-python/src/openllm/entrypoints/openai.pyi, openllm-python/src/openllm/entrypoints/__init__.pyi, openllm-python/src/openllm/entrypoints/hf.pyi, openllm-python/src/openllm/entrypoints/_openapi.pyi, openllm-python/src/openllm/entrypoints/cohere.pyi diff --git a/openllm-python/src/openllm/entrypoints/__init__.py b/openllm-python/src/openllm/entrypoints/__init__.py index 5e31d9d3..b2b4e85a 100644 --- a/openllm-python/src/openllm/entrypoints/__init__.py +++ b/openllm-python/src/openllm/entrypoints/__init__.py @@ -1,33 +1,13 @@ -"""Entrypoint for all third-party apps. - -Currently support OpenAI compatible API. 
- -Each module should implement the following API: - -- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...` -""" - -from __future__ import annotations import importlib -import typing as t from openllm_core.utils import LazyModule -if t.TYPE_CHECKING: - import bentoml - import openllm +_import_structure = {'openai': [], 'hf': [], 'cohere': []} -class IntegrationModule(t.Protocol): - def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ... - - -_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []} - - -def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: +def mount_entrypoints(svc, llm): for module_name in _import_structure: - module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__)) + module = importlib.import_module(f'.{module_name}', __name__) svc = module.mount_to_svc(svc, llm) return svc diff --git a/openllm-python/src/openllm/entrypoints/__init__.pyi b/openllm-python/src/openllm/entrypoints/__init__.pyi new file mode 100644 index 00000000..dbd38429 --- /dev/null +++ b/openllm-python/src/openllm/entrypoints/__init__.pyi @@ -0,0 +1,16 @@ +"""Entrypoint for all third-party apps. + +Currently support OpenAI, Cohere compatible API. + +Each module should implement the following API: + +- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...` +""" + +from bentoml import Service +from openllm_core._typing_compat import M, T + +from . import cohere as cohere, hf as hf, openai as openai +from .._llm import LLM + +def mount_entrypoints(svc: Service, llm: LLM[M, T]) -> Service: ... 
diff --git a/openllm-python/src/openllm/entrypoints/_openapi.py b/openllm-python/src/openllm/entrypoints/_openapi.py index e483090a..fc02c0b2 100644 --- a/openllm-python/src/openllm/entrypoints/_openapi.py +++ b/openllm-python/src/openllm/entrypoints/_openapi.py @@ -1,21 +1,14 @@ -from __future__ import annotations import functools import inspect +import types import typing as t import attr -from starlette.routing import BaseRoute, Host, Mount, Route +from starlette.routing import Host, Mount, Route from starlette.schemas import EndpointInfo, SchemaGenerator -from openllm_core._typing_compat import ParamSpec from openllm_core.utils import first_not_none -if t.TYPE_CHECKING: - from attr import AttrsInstance - - import bentoml - -P = ParamSpec('P') OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0' # NOTE: OpenAI schema LIST_MODELS_SCHEMA = """\ @@ -479,7 +472,7 @@ summary: Creates a model response for the given chat conversation. _SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')} -def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: +def add_schema_definitions(func): append_str = _SCHEMAS.get(func.__name__.lower(), '') if not append_str: return func @@ -490,8 +483,8 @@ def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: class OpenLLMSchemaGenerator(SchemaGenerator): - def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: - endpoints_info: list[EndpointInfo] = [] + def get_endpoints(self, routes): + endpoints_info = [] for route in routes: if isinstance(route, (Mount, Host)): routes = route.routes or [] @@ -523,7 +516,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator): endpoints_info.append(EndpointInfo(path, method.lower(), func)) return endpoints_info - def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]: + def get_schema(self, routes, mount_path=None): schema = dict(self.base_schema) schema.setdefault('paths', {}) 
endpoints_info = self.get_endpoints(routes) @@ -543,13 +536,8 @@ class OpenLLMSchemaGenerator(SchemaGenerator): return schema -def get_generator( - title: str, - components: list[type[AttrsInstance]] | None = None, - tags: list[dict[str, t.Any]] | None = None, - inject: bool = True, -) -> OpenLLMSchemaGenerator: - base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION) +def get_generator(title, components=None, tags=None, inject=True): + base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION} if components and inject: base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}} if tags is not None and tags and inject: @@ -557,12 +545,12 @@ def get_generator( return OpenLLMSchemaGenerator(base_schema) -def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]: - schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__} +def component_schema_generator(attr_cls, description=None): + schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__} schema['description'] = first_not_none( getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}' ) - for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var] + for field in attr.fields(attr.resolve_types(attr_cls)): attr_type = field.type origin_type = t.get_origin(attr_type) args_type = t.get_args(attr_type) @@ -593,7 +581,7 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str | if 'prop_schema' not in locals(): prop_schema = {'type': schema_type} if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): - prop_schema['default'] = field.default # type: ignore[arg-type] + prop_schema['default'] = field.default if field.default is 
attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name) schema['properties'][field.name] = prop_schema @@ -602,20 +590,15 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str | return schema -class MKSchema: - def __init__(self, it: dict[str, t.Any]) -> None: - self.it = it - - def asdict(self) -> dict[str, t.Any]: - return self.it +_SimpleSchema = types.new_class( + '_SimpleSchema', + (object,), + {}, + lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}), +) -def append_schemas( - svc: bentoml.Service, - generated_schema: dict[str, t.Any], - tags_order: t.Literal['prepend', 'append'] = 'prepend', - inject: bool = True, -) -> bentoml.Service: +def append_schemas(svc, generated_schema, tags_order='prepend', inject=True): # HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec. from bentoml._internal.service.openapi.specification import OpenAPISpecification @@ -623,7 +606,7 @@ def append_schemas( return svc svc_schema = svc.openapi_spec - if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): + if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)): svc_schema = svc_schema.asdict() if 'tags' in generated_schema: if tags_order == 'prepend': @@ -639,12 +622,9 @@ def append_schemas( # HACK: mk this attribute until we have a better way to add starlette schemas. 
from bentoml._internal.service import openapi - def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION): - return MKSchema(svc_schema) + def _generate_spec(svc, openapi_version=OPENAPI_VERSION): + return _SimpleSchema(svc_schema) - def mk_asdict(self): - return svc_schema - - openapi.generate_spec = mk_generate_spec - OpenAPISpecification.asdict = mk_asdict + openapi.generate_spec = _generate_spec + OpenAPISpecification.asdict = lambda self: svc_schema return svc diff --git a/openllm-python/src/openllm/entrypoints/_openapi.pyi b/openllm-python/src/openllm/entrypoints/_openapi.pyi new file mode 100644 index 00000000..d9d82ce5 --- /dev/null +++ b/openllm-python/src/openllm/entrypoints/_openapi.pyi @@ -0,0 +1,28 @@ +from typing import Any, Callable, Dict, List, Literal, Optional, Type + +from attr import AttrsInstance +from starlette.routing import BaseRoute +from starlette.schemas import EndpointInfo + +from bentoml import Service +from openllm_core._typing_compat import ParamSpec + +P = ParamSpec('P') + +class OpenLLMSchemaGenerator: + base_schema: Dict[str, Any] + def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: ... + def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ... + def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ... + +def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ... +def append_schemas( + svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ... +) -> Service: ... +def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ... +def get_generator( + title: str, + components: Optional[List[Type[AttrsInstance]]] = ..., + tags: Optional[List[Dict[str, Any]]] = ..., + inject: bool = ..., +) -> OpenLLMSchemaGenerator: ... 
diff --git a/openllm-python/src/openllm/entrypoints/cohere.py b/openllm-python/src/openllm/entrypoints/cohere.py index 6aee87af..dc52bfa4 100644 --- a/openllm-python/src/openllm/entrypoints/cohere.py +++ b/openllm-python/src/openllm/entrypoints/cohere.py @@ -3,7 +3,6 @@ import functools import json import logging import traceback -import typing as t from http import HTTPStatus import orjson @@ -54,26 +53,16 @@ schemas = get_generator( ) logger = logging.getLogger(__name__) -if t.TYPE_CHECKING: - from attr import AttrsInstance - from starlette.requests import Request - from starlette.responses import Response - import bentoml - import openllm - from openllm_core._schemas import GenerationOutput - from openllm_core._typing_compat import M, T - - -def jsonify_attr(obj: AttrsInstance) -> str: +def jsonify_attr(obj): return json.dumps(converter.unstructure(obj)) -def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: +def error_response(status_code, message): return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value) -async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None: +async def check_model(request, model): if request.model is None or request.model == model: return None return error_response( @@ -82,7 +71,7 @@ async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: ) -def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: +def mount_to_svc(svc, llm): app = Starlette( debug=True, routes=[ @@ -90,7 +79,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic '/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST'] ), Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']), - Route('/schema', endpoint=openapi_schema, include_in_schema=False), + Route('/schema', 
endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False), ], ) mount_path = '/cohere' @@ -102,7 +91,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic @add_schema_definitions -async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: +async def cohere_generate(req, llm): json_str = await req.body() try: request = converter.structure(orjson.loads(json_str), CohereGenerateRequest) @@ -115,7 +104,6 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: err_check = await check_model(request, llm.llm_type) if err_check is not None: return err_check - request_id = gen_random_uuid('cohere-generate') config = llm.config.with_request(request) @@ -133,10 +121,10 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') - def create_stream_response_json(index: int, text: str, is_finished: bool) -> str: + def create_stream_response_json(index, text, is_finished): return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n' - async def generate_stream_generator() -> t.AsyncGenerator[str, None]: + async def generate_stream_generator(): async for res in result_generator: for output in res.outputs: yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason) @@ -146,9 +134,8 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: if request.stream: return StreamingResponse(generate_stream_generator(), media_type='text/event-stream') # None-streaming case - final_result: GenerationOutput | None = None - texts: list[list[str]] = [[]] * config['n'] - token_ids: list[list[int]] = [[]] * config['n'] + final_result = None + texts, token_ids = [[]] * config['n'], [[]] * config['n'] async for res in result_generator: if await 
req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') @@ -164,14 +151,18 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: for output in final_result.outputs ] ) - response = Generations( - id=request_id, - generations=[ - Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) - for output in final_result.outputs - ], + return JSONResponse( + converter.unstructure( + Generations( + id=request_id, + generations=[ + Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) + for output in final_result.outputs + ], + ) + ), + status_code=HTTPStatus.OK.value, ) - return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value) except Exception as err: traceback.print_exc() logger.error('Error generating completion: %s', err) @@ -192,7 +183,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str @add_schema_definitions -async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: +async def cohere_chat(req, llm): json_str = await req.body() try: request = converter.structure(orjson.loads(json_str), CohereChatRequest) @@ -223,9 +214,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str: return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n' - async def completion_stream_generator() -> t.AsyncGenerator[str, None]: - texts: list[str] = [] - token_ids: list[int] = [] + async def completion_stream_generator(): + texts, token_ids = [], [] yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n' it = None @@ -237,9 +227,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: if it is None: raise ValueError('No response from model.') - num_prompt_tokens = 
len(t.cast(t.List[int], it.prompt_token_ids)) - num_response_tokens = len(token_ids) - total_tokens = num_prompt_tokens + num_response_tokens + num_prompt_tokens, num_response_tokens = len(it.prompt_token_ids), len(token_ids) json_str = jsonify_attr( ChatStreamEnd( @@ -255,7 +243,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: token_count={ 'prompt_tokens': num_prompt_tokens, 'response_tokens': num_response_tokens, - 'total_tokens': total_tokens, + 'total_tokens': num_prompt_tokens + num_response_tokens, }, ), ) @@ -266,9 +254,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming case - final_result: GenerationOutput | None = None - texts: list[str] = [] - token_ids: list[int] = [] + final_result = None + texts, token_ids = [], [] async for res in result_generator: if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') @@ -280,28 +267,25 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: final_result = final_result.with_options( outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)] ) - num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)) - num_response_tokens = len(token_ids) - total_tokens = num_prompt_tokens + num_response_tokens - - response = Chat( - response_id=request_id, - message=request.message, - text=''.join(texts), - prompt=prompt, - chat_history=request.chat_history, - token_count={ - 'prompt_tokens': num_prompt_tokens, - 'response_tokens': num_response_tokens, - 'total_tokens': total_tokens, - }, + num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids) + return JSONResponse( + converter.unstructure( + Chat( + response_id=request_id, + message=request.message, + text=''.join(texts), + prompt=prompt, + 
chat_history=request.chat_history, + token_count={ + 'prompt_tokens': num_prompt_tokens, + 'response_tokens': num_response_tokens, + 'total_tokens': num_prompt_tokens + num_response_tokens, + }, + ) + ), + status_code=HTTPStatus.OK.value, ) - return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value) except Exception as err: traceback.print_exc() logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') - - -def openapi_schema(req: Request) -> Response: - return schemas.OpenAPIResponse(req) diff --git a/openllm-python/src/openllm/entrypoints/cohere.pyi b/openllm-python/src/openllm/entrypoints/cohere.pyi new file mode 100644 index 00000000..51f8d46f --- /dev/null +++ b/openllm-python/src/openllm/entrypoints/cohere.pyi @@ -0,0 +1,21 @@ +from http import HTTPStatus +from typing import Optional, Union + +from attr import AttrsInstance +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from bentoml import Service +from openllm_core._typing_compat import M, T + +from .._llm import LLM +from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest + +def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ... +def jsonify_attr(obj: AttrsInstance) -> str: ... +def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ... +async def check_model( + request: Union[CohereGenerateRequest, CohereChatRequest], model: str +) -> Optional[JSONResponse]: ... +async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ... +async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ... 
diff --git a/openllm-python/src/openllm/entrypoints/hf.py b/openllm-python/src/openllm/entrypoints/hf.py index faa11acb..d4e9b86f 100644 --- a/openllm-python/src/openllm/entrypoints/hf.py +++ b/openllm-python/src/openllm/entrypoints/hf.py @@ -1,8 +1,5 @@ -from __future__ import annotations import functools import logging -import typing as t -from enum import Enum from http import HTTPStatus import orjson @@ -28,23 +25,14 @@ schemas = get_generator( ) logger = logging.getLogger(__name__) -if t.TYPE_CHECKING: - from peft.config import PeftConfig - from starlette.requests import Request - from starlette.responses import Response - import bentoml - import openllm - from openllm_core._typing_compat import M, T - - -def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: +def mount_to_svc(svc, llm): app = Starlette( debug=True, routes=[ Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']), Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']), - Route('/schema', endpoint=openapi_schema, include_in_schema=False), + Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False), ], ) mount_path = '/hf' @@ -52,7 +40,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append') -def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: +def error_response(status_code, message): return JSONResponse( converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value, @@ -60,7 +48,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: @add_schema_definitions -async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response: +async def hf_agent(req, llm): json_str = await req.body() try: request = 
converter.structure(orjson.loads(json_str), AgentRequest) @@ -81,17 +69,13 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response: @add_schema_definitions -def hf_adapters(req: Request, llm: openllm.LLM[M, T]) -> Response: +def hf_adapters(req, llm): if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.') return JSONResponse( { - adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value} - for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items() + adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value} + for k, adapter_tuple in dict(*llm.adapter_map.values()).items() }, status_code=HTTPStatus.OK.value, ) - - -def openapi_schema(req: Request) -> Response: - return schemas.OpenAPIResponse(req) diff --git a/openllm-python/src/openllm/entrypoints/hf.pyi b/openllm-python/src/openllm/entrypoints/hf.pyi new file mode 100644 index 00000000..d5cc4a26 --- /dev/null +++ b/openllm-python/src/openllm/entrypoints/hf.pyi @@ -0,0 +1,14 @@ +from http import HTTPStatus + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from bentoml import Service +from openllm_core._typing_compat import M, T + +from .._llm import LLM + +def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ... +def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ... +async def hf_agent(req: Request, llm: LLM[M, T]) -> Response: ... +def hf_adapters(req: Request, llm: LLM[M, T]) -> Response: ... 
diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py index 0f76f1c5..e6309c59 100644 --- a/openllm-python/src/openllm/entrypoints/openai.py +++ b/openllm-python/src/openllm/entrypoints/openai.py @@ -1,4 +1,3 @@ -from __future__ import annotations import functools import logging import time @@ -57,22 +56,12 @@ schemas = get_generator( ) logger = logging.getLogger(__name__) -if t.TYPE_CHECKING: - from attr import AttrsInstance - from starlette.requests import Request - from starlette.responses import Response - import bentoml - import openllm - from openllm_core._schemas import GenerationOutput - from openllm_core._typing_compat import M, T - - -def jsonify_attr(obj: AttrsInstance) -> str: +def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode() -def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: +def error_response(status_code, message): return JSONResponse( { 'error': converter.unstructure( @@ -83,7 +72,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ) -async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None: +async def check_model(request, model): if request.model == model: return None return error_response( @@ -92,9 +81,7 @@ async def check_model(request: CompletionRequest | ChatCompletionRequest, model: ) -def create_logprobs( - token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T] -) -> LogProbs: +def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm): # Create OpenAI-style logprobs. 
logprobs = LogProbs() last_token_len = 0 @@ -112,24 +99,23 @@ def create_logprobs( return logprobs -def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: +def mount_to_svc(svc, llm): app = Starlette( debug=True, routes=[ Route('/models', functools.partial(list_models, llm=llm), methods=['GET']), Route('/completions', functools.partial(completions, llm=llm), methods=['POST']), Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']), - Route('/schema', endpoint=openapi_schema, include_in_schema=False), + Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False), ], ) - mount_path = '/v1' - svc.mount_asgi_app(app, path=mount_path) - return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path)) + svc.mount_asgi_app(app, path='/v1') + return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path='/v1')) # GET /v1/models @add_schema_definitions -def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response: +def list_models(_, llm): return JSONResponse( converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value ) @@ -137,7 +123,7 @@ def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response: # POST /v1/chat/completions @add_schema_definitions -async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response: +async def chat_completions(req, llm): # TODO: Check for length based on model context_length json_str = await req.body() try: @@ -166,7 +152,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response: logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') - def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str: + def create_stream_response_json(index, text, finish_reason=None): return jsonify_attr( 
ChatCompletionStreamResponse( id=request_id, @@ -178,7 +164,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response: ) ) - async def completion_stream_generator() -> t.AsyncGenerator[str, None]: + async def completion_stream_generator(): # first chunk with role for i in range(config['n']): yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n" @@ -195,9 +181,8 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response: if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming case - final_result: GenerationOutput | None = None - texts: list[list[str]] = [[]] * config['n'] - token_ids: list[list[int]] = [[]] * config['n'] + final_result = None + texts, token_ids = [[]] * config['n'], [[]] * config['n'] async for res in result_generator: if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') @@ -221,15 +206,9 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response: ) for output in final_result.outputs ] - num_prompt_tokens, num_generated_tokens = ( - len(t.cast(t.List[int], final_result.prompt_token_ids)), - sum(len(output.token_ids) for output in final_result.outputs), - ) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) + num_prompt_tokens = len(final_result.prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs) + usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens) response = ChatCompletionResponse( id=request_id, created=created_time, model=model_name, usage=usage, choices=choices ) @@ -254,7 +233,7 @@ async def chat_completions(req: 
Request, llm: openllm.LLM[M, T]) -> Response: # POST /v1/completions @add_schema_definitions -async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: +async def completions(req, llm): # TODO: Check for length based on model context_length json_str = await req.body() try: @@ -300,9 +279,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: # TODO: support use_beam_search stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of']) - def create_stream_response_json( - index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None - ) -> str: + def create_stream_response_json(index, text, logprobs=None, finish_reason=None): return jsonify_attr( CompletionStreamResponse( id=request_id, @@ -314,16 +291,14 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: ) ) - async def completion_stream_generator() -> t.AsyncGenerator[str, None]: + async def completion_stream_generator(): previous_num_tokens = [0] * config['n'] async for res in result_generator: for output in res.outputs: i = output.index if request.logprobs is not None: logprobs = create_logprobs( - token_ids=output.token_ids, - id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :], - llm=llm, + token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm ) else: logprobs = None @@ -339,9 +314,8 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming case - final_result: GenerationOutput | None = None - texts: list[list[str]] = [[]] * config['n'] - token_ids: list[list[int]] = [[]] * config['n'] + final_result = None + texts, token_ids = [[]] * config['n'], [[]] * config['n'] async for res in result_generator: if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') 
@@ -358,7 +332,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: ] ) - choices: list[CompletionResponseChoice] = [] + choices = [] for output in final_result.outputs: if request.logprobs is not None: logprobs = create_logprobs( @@ -371,15 +345,9 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: ) choices.append(choice_data) - num_prompt_tokens = len( - t.cast(t.List[int], final_result.prompt_token_ids) - ) # XXX: We will always return prompt_token_ids, so this won't be None + num_prompt_tokens = len(final_result.prompt_token_ids) num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens) response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices) if request.stream: @@ -398,7 +366,3 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response: traceback.print_exc() logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') - - -def openapi_schema(req: Request) -> Response: - return schemas.OpenAPIResponse(req) diff --git a/openllm-python/src/openllm/entrypoints/openai.pyi b/openllm-python/src/openllm/entrypoints/openai.pyi new file mode 100644 index 00000000..0935c227 --- /dev/null +++ b/openllm-python/src/openllm/entrypoints/openai.pyi @@ -0,0 +1,25 @@ +from http import HTTPStatus +from typing import Dict, List, Optional, Union + +from attr import AttrsInstance +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from bentoml import Service +from openllm_core._typing_compat import M, T + +from .._llm import LLM +from
..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs + +def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ... +def jsonify_attr(obj: AttrsInstance) -> str: ... +def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ... +async def check_model( + request: Union[CompletionRequest, ChatCompletionRequest], model: str +) -> Optional[JSONResponse]: ... +def create_logprobs( + token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T] +) -> LogProbs: ... +def list_models(req: Request, llm: LLM[M, T]) -> Response: ... +async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ... +async def completions(req: Request, llm: LLM[M, T]) -> Response: ... diff --git a/openllm-python/src/openllm_cli/_factory.py b/openllm-python/src/openllm_cli/_factory.py index b55f27eb..374b664c 100644 --- a/openllm-python/src/openllm_cli/_factory.py +++ b/openllm-python/src/openllm_cli/_factory.py @@ -8,7 +8,7 @@ import click import click_option_group as cog import inflection from bentoml_cli.utils import BentoMLCommandGroup -from click import ClickException, shell_completion as sc +from click import shell_completion as sc import bentoml import openllm @@ -22,7 +22,7 @@ from openllm_core._typing_compat import ( ParamSpec, get_literal_args, ) -from openllm_core.utils import DEBUG +from openllm_core.utils import DEBUG, resolve_user_filepath class _OpenLLM_GenericInternalConfig(LLMConfig): @@ -125,12 +125,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ... 
# try to resolve the full path if users pass in relative, # currently only support one level of resolve path with current directory try: - adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd()) + adapter_id = resolve_user_filepath(adapter_id, os.getcwd()) except FileNotFoundError: pass - if len(adapter_name) == 0: - raise ClickException(f'Adapter name is required for {adapter_id}') - ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] + name = adapter_name[0] if len(adapter_name) > 0 else 'default' + ctx.params[_adapter_mapping_key][adapter_id] = name return None @@ -171,32 +170,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC help='Assign GPU devices (if available)', show_envvar=True, ), - cog.optgroup.group( - 'Fine-tuning related options', - help="""\ - Note that the argument `--adapter-id` can accept the following format: - - - `--adapter-id /path/to/adapter` (local adapter) - - - `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub) - - - `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name) - - ```bash - - $ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora - - ``` - """, - ), - cog.optgroup.option( - '--adapter-id', - default=None, - help='Optional name or path for given LoRA adapter', - multiple=True, - callback=_id_callback, - metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]', - ), + adapter_id_option(factory=cog.optgroup), click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True), ) return composed(fn) @@ -285,6 +259,17 @@ cli_option = functools.partial(_click_factory_type, attr='option') cli_argument = functools.partial(_click_factory_type, attr='argument') +def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + '--adapter-id', + default=None, + help='Optional name or path for 
given LoRA adapter', + multiple=True, + callback=_id_callback, + metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]', **attrs, + )(f) + + def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--cors/--no-cors', diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index c809c110..4d705656 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -691,7 +691,7 @@ def pretty_print(line: str): elif 'ERROR' in line: caller = termui.error else: - caller = caller = functools.partial(termui.echo, fg=None) + caller = functools.partial(termui.echo, fg=None) caller(line.strip()) @@ -1026,6 +1026,7 @@ def build_command( 'OPENLLM_BACKEND': llm.__llm_backend__, 'OPENLLM_SERIALIZATION': llm._serialisation, 'OPENLLM_MODEL_ID': llm.model_id, + 'OPENLLM_ADAPTER_MAP': orjson.dumps(None).decode(), } ) if llm.quantise: @@ -1035,13 +1036,8 @@ def build_command( if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template - # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError - # during build.
This is a current limitation of bentoml build where we actually import the service.py into sys.path try: assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything - # FIX: This is a patch for _service_vars injection - if 'OPENLLM_ADAPTER_MAP' not in os.environ: - os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode() labels = dict(llm.identifying_params) labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__}) @@ -1190,14 +1186,8 @@ class ModelItem(t.TypedDict): @cli.command() -@click.option( - '--show-available', - is_flag=True, - default=True, - hidden=True, - help="Show available models in local store (mutually exclusive with '-o porcelain').", -) -def models_command(show_available: bool) -> dict[t.LiteralString, ModelItem]: +@click.option('--show-available', is_flag=True, default=True, hidden=True) +def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]: """List all supported models. \b