feat(openai): dynamic model_type registration (#704)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-20 00:13:45 -05:00
committed by GitHub
parent 6505abdb44
commit 513c08ccda
3 changed files with 29 additions and 9 deletions

View File

@@ -35,7 +35,7 @@ responses:
example:
object: 'list'
data:
- id: meta-llama--Llama-2-13b-chat-hf
- id: __model_id__
object: model
created: 1686935002
owned_by: 'na'
@@ -69,7 +69,7 @@ requestBody:
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
@@ -83,7 +83,7 @@ requestBody:
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
@@ -206,7 +206,7 @@ requestBody:
summary: One-shot input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
logprobs: 1
@@ -217,7 +217,7 @@ requestBody:
summary: Streaming input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
@@ -472,6 +472,12 @@ summary: Creates a model response for the given chat conversation.
# Registry of OpenAPI doc snippets: maps each module-level *_SCHEMA variable to its
# lowercased base name (e.g. LIST_MODELS_SCHEMA -> 'list_models'); k[:-7] strips the
# 7-character '_SCHEMA' suffix. Looked up later by endpoint function name.
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
def apply_schema(func, **attrs):
  """Substitute placeholders in ``func``'s docstring and return ``func``.

  Each keyword name in ``attrs`` is treated as a literal placeholder string
  (e.g. ``__model_id__``) and every occurrence of it in ``func.__doc__`` is
  replaced with the corresponding value. Mutates ``func.__doc__`` in place.

  Args:
    func: The callable whose docstring will be rewritten.
    **attrs: Mapping of placeholder -> replacement text.

  Returns:
    The same ``func`` object, for decorator-style chaining.
  """
  # Guard: functions without a docstring have __doc__ == None; calling
  # .replace() on it would raise. Nothing to substitute in that case.
  if func.__doc__ is None:
    return func
  for k, v in attrs.items():
    func.__doc__ = func.__doc__.replace(k, v)
  return func
def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:

View File

@@ -15,6 +15,7 @@ class OpenLLMSchemaGenerator:
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...
def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
def append_schemas(
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...

View File

@@ -13,7 +13,7 @@ from starlette.routing import Route
from openllm_core._schemas import SampleLogprobs
from openllm_core.utils import converter, gen_random_uuid
from ._openapi import add_schema_definitions, append_schemas, get_generator
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
from ..protocol.openai import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -100,12 +100,25 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):
def mount_to_svc(svc, llm):
list_models.__doc__ = list_models.__doc__.replace('__model_id__', llm.llm_type)
completions.__doc__ = completions.__doc__.replace('__model_id__', llm.llm_type)
chat_completions.__doc__ = chat_completions.__doc__.replace('__model_id__', llm.llm_type)
app = Starlette(
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
Route(
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
),
Route(
'/completions',
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route(
'/chat/completions',
functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)