mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-11 18:09:52 -04:00
types: update stubs for remaining entrypoints (#667)
* perf(type): static OpenAI types definition Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * feat: add hf types Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * types: update remaining missing stubs Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,33 +1,13 @@
|
||||
"""Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import openllm
|
||||
_import_structure = {'openai': [], 'hf': [], 'cohere': []}
|
||||
|
||||
|
||||
class IntegrationModule(t.Protocol):
|
||||
def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ...
|
||||
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
|
||||
def mount_entrypoints(svc, llm):
|
||||
for module_name in _import_structure:
|
||||
module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__))
|
||||
module = importlib.import_module(f'.{module_name}', __name__)
|
||||
svc = module.mount_to_svc(svc, llm)
|
||||
return svc
|
||||
|
||||
|
||||
16
openllm-python/src/openllm/entrypoints/__init__.pyi
Normal file
16
openllm-python/src/openllm/entrypoints/__init__.pyi
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI, Cohere compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from . import cohere as cohere, hf as hf, openai as openai
|
||||
from .._llm import LLM
|
||||
|
||||
def mount_entrypoints(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
@@ -1,21 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
from starlette.routing import Host, Mount, Route
|
||||
from starlette.schemas import EndpointInfo, SchemaGenerator
|
||||
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
|
||||
import bentoml
|
||||
|
||||
P = ParamSpec('P')
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
@@ -479,7 +472,7 @@ summary: Creates a model response for the given chat conversation.
|
||||
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
|
||||
|
||||
|
||||
def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
def add_schema_definitions(func):
|
||||
append_str = _SCHEMAS.get(func.__name__.lower(), '')
|
||||
if not append_str:
|
||||
return func
|
||||
@@ -490,8 +483,8 @@ def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
|
||||
|
||||
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
def get_endpoints(self, routes):
|
||||
endpoints_info = []
|
||||
for route in routes:
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
@@ -523,7 +516,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
|
||||
def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]:
|
||||
def get_schema(self, routes, mount_path=None):
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
@@ -543,13 +536,8 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
return schema
|
||||
|
||||
|
||||
def get_generator(
|
||||
title: str,
|
||||
components: list[type[AttrsInstance]] | None = None,
|
||||
tags: list[dict[str, t.Any]] | None = None,
|
||||
inject: bool = True,
|
||||
) -> OpenLLMSchemaGenerator:
|
||||
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
|
||||
def get_generator(title, components=None, tags=None, inject=True):
|
||||
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
|
||||
if components and inject:
|
||||
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags and inject:
|
||||
@@ -557,12 +545,12 @@ def get_generator(
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
|
||||
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
def component_schema_generator(attr_cls, description=None):
|
||||
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)):
|
||||
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
@@ -593,7 +581,7 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
if 'prop_schema' not in locals():
|
||||
prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
prop_schema['default'] = field.default
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
@@ -602,20 +590,15 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
return schema
|
||||
|
||||
|
||||
class MKSchema:
|
||||
def __init__(self, it: dict[str, t.Any]) -> None:
|
||||
self.it = it
|
||||
|
||||
def asdict(self) -> dict[str, t.Any]:
|
||||
return self.it
|
||||
_SimpleSchema = types.new_class(
|
||||
'_SimpleSchema',
|
||||
(object,),
|
||||
{},
|
||||
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
|
||||
)
|
||||
|
||||
|
||||
def append_schemas(
|
||||
svc: bentoml.Service,
|
||||
generated_schema: dict[str, t.Any],
|
||||
tags_order: t.Literal['prepend', 'append'] = 'prepend',
|
||||
inject: bool = True,
|
||||
) -> bentoml.Service:
|
||||
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
|
||||
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
|
||||
@@ -623,7 +606,7 @@ def append_schemas(
|
||||
return svc
|
||||
|
||||
svc_schema = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
|
||||
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
|
||||
svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend':
|
||||
@@ -639,12 +622,9 @@ def append_schemas(
|
||||
# HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
from bentoml._internal.service import openapi
|
||||
|
||||
def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
|
||||
return MKSchema(svc_schema)
|
||||
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
|
||||
return _SimpleSchema(svc_schema)
|
||||
|
||||
def mk_asdict(self):
|
||||
return svc_schema
|
||||
|
||||
openapi.generate_spec = mk_generate_spec
|
||||
OpenAPISpecification.asdict = mk_asdict
|
||||
openapi.generate_spec = _generate_spec
|
||||
OpenAPISpecification.asdict = lambda self: svc_schema
|
||||
return svc
|
||||
|
||||
28
openllm-python/src/openllm/entrypoints/_openapi.pyi
Normal file
28
openllm-python/src/openllm/entrypoints/_openapi.pyi
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.schemas import EndpointInfo
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
class OpenLLMSchemaGenerator:
|
||||
base_schema: Dict[str, Any]
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: ...
|
||||
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
|
||||
def append_schemas(
|
||||
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
|
||||
) -> Service: ...
|
||||
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def get_generator(
|
||||
title: str,
|
||||
components: Optional[List[Type[AttrsInstance]]] = ...,
|
||||
tags: Optional[List[Dict[str, Any]]] = ...,
|
||||
inject: bool = ...,
|
||||
) -> OpenLLMSchemaGenerator: ...
|
||||
@@ -3,7 +3,6 @@ import functools
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import typing as t
|
||||
from http import HTTPStatus
|
||||
|
||||
import orjson
|
||||
@@ -54,26 +53,16 @@ schemas = get_generator(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
|
||||
def jsonify_attr(obj: AttrsInstance) -> str:
|
||||
def jsonify_attr(obj):
|
||||
return json.dumps(converter.unstructure(obj))
|
||||
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)
|
||||
|
||||
|
||||
async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None:
|
||||
async def check_model(request, model):
|
||||
if request.model is None or request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
@@ -82,7 +71,7 @@ async def check_model(request: CohereGenerateRequest | CohereChatRequest, model:
|
||||
)
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
@@ -90,7 +79,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
|
||||
'/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']
|
||||
),
|
||||
Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/cohere'
|
||||
@@ -102,7 +91,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
async def cohere_generate(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
|
||||
@@ -115,7 +104,6 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
request_id = gen_random_uuid('cohere-generate')
|
||||
config = llm.config.with_request(request)
|
||||
|
||||
@@ -133,10 +121,10 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index: int, text: str, is_finished: bool) -> str:
|
||||
def create_stream_response_json(index, text, is_finished):
|
||||
return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'
|
||||
|
||||
async def generate_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
async def generate_stream_generator():
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason)
|
||||
@@ -146,9 +134,8 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if request.stream:
|
||||
return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
|
||||
# None-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
@@ -164,14 +151,18 @@ async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
response = Generations(
|
||||
id=request_id,
|
||||
generations=[
|
||||
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
|
||||
for output in final_result.outputs
|
||||
],
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
Generations(
|
||||
id=request_id,
|
||||
generations=[
|
||||
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
|
||||
for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
),
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
@@ -192,7 +183,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
async def cohere_chat(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CohereChatRequest)
|
||||
@@ -223,9 +214,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
|
||||
return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
texts: list[str] = []
|
||||
token_ids: list[int] = []
|
||||
async def completion_stream_generator():
|
||||
texts, token_ids = [], []
|
||||
yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'
|
||||
|
||||
it = None
|
||||
@@ -237,9 +227,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
if it is None:
|
||||
raise ValueError('No response from model.')
|
||||
num_prompt_tokens = len(t.cast(t.List[int], it.prompt_token_ids))
|
||||
num_response_tokens = len(token_ids)
|
||||
total_tokens = num_prompt_tokens + num_response_tokens
|
||||
num_prompt_tokens, num_response_tokens = len(it.prompt_token_ids), len(token_ids)
|
||||
|
||||
json_str = jsonify_attr(
|
||||
ChatStreamEnd(
|
||||
@@ -255,7 +243,7 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
token_count={
|
||||
'prompt_tokens': num_prompt_tokens,
|
||||
'response_tokens': num_response_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
'total_tokens': num_prompt_tokens + num_response_tokens,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -266,9 +254,8 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[str] = []
|
||||
token_ids: list[int] = []
|
||||
final_result = None
|
||||
texts, token_ids = [], []
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
@@ -280,28 +267,25 @@ async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
final_result = final_result.with_options(
|
||||
outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
|
||||
)
|
||||
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids))
|
||||
num_response_tokens = len(token_ids)
|
||||
total_tokens = num_prompt_tokens + num_response_tokens
|
||||
|
||||
response = Chat(
|
||||
response_id=request_id,
|
||||
message=request.message,
|
||||
text=''.join(texts),
|
||||
prompt=prompt,
|
||||
chat_history=request.chat_history,
|
||||
token_count={
|
||||
'prompt_tokens': num_prompt_tokens,
|
||||
'response_tokens': num_response_tokens,
|
||||
'total_tokens': total_tokens,
|
||||
},
|
||||
num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids)
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
Chat(
|
||||
response_id=request_id,
|
||||
message=request.message,
|
||||
text=''.join(texts),
|
||||
prompt=prompt,
|
||||
chat_history=request.chat_history,
|
||||
token_count={
|
||||
'prompt_tokens': num_prompt_tokens,
|
||||
'response_tokens': num_response_tokens,
|
||||
'total_tokens': num_prompt_tokens + num_response_tokens,
|
||||
},
|
||||
)
|
||||
),
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
|
||||
21
openllm-python/src/openllm/entrypoints/cohere.pyi
Normal file
21
openllm-python/src/openllm/entrypoints/cohere.pyi
Normal file
@@ -0,0 +1,21 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(
|
||||
request: Union[CohereGenerateRequest, CohereChatRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -1,8 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
import typing as t
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
|
||||
import orjson
|
||||
@@ -28,23 +25,14 @@ schemas = get_generator(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from peft.config import PeftConfig
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/hf'
|
||||
@@ -52,7 +40,7 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
|
||||
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
|
||||
status_code=status_code.value,
|
||||
@@ -60,7 +48,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
async def hf_agent(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), AgentRequest)
|
||||
@@ -81,17 +69,13 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
def hf_adapters(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
def hf_adapters(req, llm):
|
||||
if not llm.has_adapters:
|
||||
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
return JSONResponse(
|
||||
{
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value}
|
||||
for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value}
|
||||
for k, adapter_tuple in dict(*llm.adapter_map.values()).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
|
||||
14
openllm-python/src/openllm/entrypoints/hf.pyi
Normal file
14
openllm-python/src/openllm/entrypoints/hf.pyi
Normal file
@@ -0,0 +1,14 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def hf_agent(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
def hf_adapters(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
@@ -57,22 +56,12 @@ schemas = get_generator(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
|
||||
def jsonify_attr(obj: AttrsInstance) -> str:
|
||||
def jsonify_attr(obj):
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
@@ -83,7 +72,7 @@ def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
async def check_model(request, model):
|
||||
if request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
@@ -92,9 +81,7 @@ async def check_model(request: CompletionRequest | ChatCompletionRequest, model:
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(
|
||||
token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]
|
||||
) -> LogProbs:
|
||||
def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
@@ -112,24 +99,23 @@ def create_logprobs(
|
||||
return logprobs
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/v1'
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path))
|
||||
svc.mount_asgi_app(app, path='/v1')
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path='/v1'))
|
||||
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions
|
||||
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
def list_models(_, llm):
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
@@ -137,7 +123,7 @@ def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions
|
||||
async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
async def chat_completions(req, llm):
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
@@ -166,7 +152,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
|
||||
def create_stream_response_json(index, text, finish_reason=None):
|
||||
return jsonify_attr(
|
||||
ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
@@ -178,7 +164,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
async def completion_stream_generator():
|
||||
# first chunk with role
|
||||
for i in range(config['n']):
|
||||
yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
|
||||
@@ -195,9 +181,8 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
@@ -221,15 +206,9 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
num_prompt_tokens, num_generated_tokens = (
|
||||
len(t.cast(t.List[int], final_result.prompt_token_ids)),
|
||||
sum(len(output.token_ids) for output in final_result.outputs),
|
||||
)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
@@ -254,7 +233,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
# POST /v1/completions
|
||||
@add_schema_definitions
|
||||
async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
async def completions(req, llm):
|
||||
# TODO: Check for length based on model context_length
|
||||
json_str = await req.body()
|
||||
try:
|
||||
@@ -300,9 +279,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
# TODO: support use_beam_search
|
||||
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None
|
||||
) -> str:
|
||||
def create_stream_response_json(index, text, logprobs=None, finish_reason=None):
|
||||
return jsonify_attr(
|
||||
CompletionStreamResponse(
|
||||
id=request_id,
|
||||
@@ -314,16 +291,14 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
async def completion_stream_generator():
|
||||
previous_num_tokens = [0] * config['n']
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids,
|
||||
id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :],
|
||||
llm=llm,
|
||||
token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
@@ -339,9 +314,8 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
@@ -358,7 +332,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
]
|
||||
)
|
||||
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
choices = []
|
||||
for output in final_result.outputs:
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
@@ -371,15 +345,9 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(
|
||||
t.cast(t.List[int], final_result.prompt_token_ids)
|
||||
) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream:
|
||||
@@ -398,7 +366,3 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
|
||||
25
openllm-python/src/openllm/entrypoints/openai.pyi
Normal file
25
openllm-python/src/openllm/entrypoints/openai.pyi
Normal file
@@ -0,0 +1,25 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
def create_logprobs(
|
||||
token_ids: List[int], id_logprobs: List[Dict[int, float]], initial_text_offset: int = ..., *, llm: LLM[M, T]
|
||||
) -> LogProbs: ...
|
||||
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def completions(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -8,7 +8,7 @@ import click
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import ClickException, shell_completion as sc
|
||||
from click import shell_completion as sc
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
@@ -22,7 +22,7 @@ from openllm_core._typing_compat import (
|
||||
ParamSpec,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import DEBUG, resolve_user_filepath
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
@@ -125,12 +125,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
if len(adapter_name) == 0:
|
||||
raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
|
||||
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = name
|
||||
return None
|
||||
|
||||
|
||||
@@ -171,32 +170,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
cog.optgroup.group(
|
||||
'Fine-tuning related options',
|
||||
help="""\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
|
||||
- `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub)
|
||||
|
||||
- `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name)
|
||||
|
||||
```bash
|
||||
|
||||
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
),
|
||||
adapter_id_option(factory=cog.optgroup),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
return composed(fn)
|
||||
@@ -285,6 +259,17 @@ cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
|
||||
def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
)
|
||||
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
|
||||
@@ -691,7 +691,7 @@ def pretty_print(line: str):
|
||||
elif 'ERROR' in line:
|
||||
caller = termui.error
|
||||
else:
|
||||
caller = caller = functools.partial(termui.echo, fg=None)
|
||||
caller = functools.partial(termui.echo, fg=None)
|
||||
caller(line.strip())
|
||||
|
||||
|
||||
@@ -1026,6 +1026,7 @@ def build_command(
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_SERIALIZATION': llm._serialisation,
|
||||
'OPENLLM_MODEL_ID': llm.model_id,
|
||||
'OPENLLM_ADAPTER_MAP': orjson.dumps(None).decode(),
|
||||
}
|
||||
)
|
||||
if llm.quantise:
|
||||
@@ -1035,13 +1036,8 @@ def build_command(
|
||||
if prompt_template:
|
||||
os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
|
||||
try:
|
||||
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
|
||||
# FIX: This is a patch for _service_vars injection
|
||||
if 'OPENLLM_ADAPTER_MAP' not in os.environ:
|
||||
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__})
|
||||
@@ -1190,14 +1186,8 @@ class ModelItem(t.TypedDict):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
'--show-available',
|
||||
is_flag=True,
|
||||
default=True,
|
||||
hidden=True,
|
||||
help="Show available models in local store (mutually exclusive with '-o porcelain').",
|
||||
)
|
||||
def models_command(show_available: bool) -> dict[t.LiteralString, ModelItem]:
|
||||
@click.option('--show-available', is_flag=True, default=True, hidden=True)
|
||||
def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
|
||||
"""List all supported models.
|
||||
|
||||
\b
|
||||
|
||||
Reference in New Issue
Block a user