experimental: Cohere compatible endpoints. (#644)

* feat: add generate endpoint

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update generation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix(cohere): generate endpoints

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: --wip--

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: update testing clients and chat implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: disable schemas for easter eggs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-14 01:07:43 -05:00
committed by GitHub
parent 0bf6ec7537
commit b0ab8ccdf6
8 changed files with 638 additions and 16 deletions

View File

@@ -8,21 +8,28 @@ Each module should implement the following API:
"""
from __future__ import annotations
import importlib
import typing as t
from openllm_core.utils import LazyModule
from . import hf as hf, openai as openai
if t.TYPE_CHECKING:
import bentoml
import openllm
class IntegrationModule(t.Protocol):
  """Structural type that every integration submodule (openai, hf, cohere) must implement."""

  def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ...

# Registered integration submodules; each is imported lazily and mounted below.
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []}

def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
  """Mount every registered integration's HTTP routes onto ``svc`` and return it.

  NOTE: the diff artifact left an old one-line ``return openai.mount_to_svc(...)``
  before the loop, which made the generic loop unreachable, and a duplicate
  ``_import_structure`` definition; both removed here.
  """
  for module_name in _import_structure:
    module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__))
    svc = module.mount_to_svc(svc, llm)
  return svc
__lazy = LazyModule(

View File

@@ -396,7 +396,7 @@ produces:
summary: Describes a model offering that can be used with the API.
tags:
- HF
x-bentoml-name: adapters_map
x-bentoml-name: hf_adapters
responses:
200:
description: Return list of LoRA adapters.
@@ -416,6 +416,65 @@ responses:
$ref: '#/components/schemas/HFErrorResponse'
description: Not Found
"""
COHERE_GENERATE_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a prompt, the model will return one or more predicted completions, and
can also return the probabilities of alternative tokens at each position.
operationId: cohere__generate
produces:
- application/json
tags:
- Cohere
x-bentoml-name: cohere_generate
summary: Creates a completion for the provided prompt and parameters.
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/CohereGenerateRequest'
examples:
one-shot:
summary: One-shot input example
value:
prompt: This is a test
max_tokens: 256
temperature: 0.7
p: 0.43
k: 12
num_generations: 2
stream: false
streaming:
summary: Streaming input example
value:
prompt: This is a test
max_tokens: 256
temperature: 0.7
p: 0.43
k: 12
num_generations: 2
stream: true
stop_sequences:
- "\\n"
- "<|endoftext|>"
"""
COHERE_CHAT_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a list of messages comprising a conversation, the model will return a response.
operationId: cohere__chat
produces:
- application/json
tags:
- Cohere
x-bentoml-name: cohere_chat
summary: Creates a model response for the given chat conversation.
"""
# Index every module-level *_SCHEMA constant by its lowercased stem, e.g.
# COHERE_GENERATE_SCHEMA -> 'cohere_generate' ([:-7] strips the '_SCHEMA' suffix).
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
@@ -485,12 +544,15 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
def get_generator(
  title: str,
  components: list[type[AttrsInstance]] | None = None,
  tags: list[dict[str, t.Any]] | None = None,
  inject: bool = True,
) -> OpenLLMSchemaGenerator:
  """Build an ``OpenLLMSchemaGenerator`` seeded with a base OpenAPI document.

  Args:
    title: Title placed in the OpenAPI ``info`` block.
    components: attrs classes to publish under ``components.schemas``.
    tags: OpenAPI tag descriptors for the document.
    inject: When False, skip injecting both components and tags — used by
      callers (e.g. the Cohere routes) that append their schema text manually.

  NOTE: the diff artifact left the pre-change signature line and the old
  ``if components:`` / ``if tags is not None and tags:`` conditions interleaved
  with their replacements; reconstructed to the post-change form here.
  """
  base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
  if components and inject:
    base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
  if tags is not None and tags and inject:
    base_schema['tags'] = tags
  return OpenLLMSchemaGenerator(base_schema)

View File

@@ -0,0 +1,317 @@
from __future__ import annotations
import functools
import json
import logging
import traceback
import typing as t
from http import HTTPStatus
import orjson
from starlette.applications import Starlette
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route
from openllm_core.utils import converter, gen_random_uuid
from ._openapi import append_schemas, get_generator
# from ._openapi import add_schema_definitions
from ..protocol.cohere import (
Chat,
ChatStreamEnd,
ChatStreamStart,
ChatStreamTextGeneration,
CohereChatRequest,
CohereErrorResponse,
CohereGenerateRequest,
Generation,
Generations,
StreamingGenerations,
StreamingText,
)
# Module-level schema generator for the Cohere-compatible routes.
# inject=False: the components/tags are NOT injected into the base document here;
# the route docs come from the *_SCHEMA YAML fragments appended in mount_to_svc.
schemas = get_generator(
  'cohere',
  components=[
    CohereChatRequest,
    CohereErrorResponse,
    CohereGenerateRequest,
    Generation,
    Generations,
    StreamingGenerations,
    StreamingText,
    Chat,
    ChatStreamStart,
    ChatStreamEnd,
    ChatStreamTextGeneration,
  ],
  tags=[
    {
      'name': 'Cohere',
      'description': 'Cohere compatible API. Currently support /generate, /chat',
      'externalDocs': 'https://docs.cohere.com/docs/the-cohere-platform',
    }
  ],
  inject=False,
)
logger = logging.getLogger(__name__)
if t.TYPE_CHECKING:
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import Response
import bentoml
import openllm
from openllm_core._schemas import GenerationOutput
from openllm_core._typing_compat import M, T
def jsonify_attr(obj: AttrsInstance) -> str:
  """Serialize an attrs instance to a JSON string via the shared converter."""
  payload = converter.unstructure(obj)
  return json.dumps(payload)
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
  """Return a Cohere-style JSON error body with the given HTTP status code."""
  body = converter.unstructure(CohereErrorResponse(text=message))
  return JSONResponse(body, status_code=status_code.value)
async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None:
  """Validate the requested model name against the running model.

  Returns:
    None when the request targets this model (or names no model at all),
    otherwise a 404 JSON error response.
  """
  if request.model is None or request.model == model:
    return None
  # Fix: user-facing message said "does not exists" (grammar).
  return error_response(
    HTTPStatus.NOT_FOUND,
    f"Model '{request.model}' does not exist. Try 'GET /v1/models' to see current running models.",
  )
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
  """Mount the Cohere-compatible ASGI sub-app under ``/cohere`` and append its OpenAPI schema."""
  routes = [
    Route(
      '/v1/generate',
      endpoint=functools.partial(cohere_generate, llm=llm),
      name='cohere_generate',
      methods=['POST'],
      include_in_schema=False,
    ),
    Route(
      '/v1/chat',
      endpoint=functools.partial(cohere_chat, llm=llm),
      name='cohere_chat',
      methods=['POST'],
      include_in_schema=False,
    ),
    Route('/schema', endpoint=openapi_schema, include_in_schema=False),
  ]
  app = Starlette(debug=True, routes=routes)
  mount_path = '/cohere'
  svc.mount_asgi_app(app, path=mount_path)
  generated = schemas.get_schema(routes=app.routes, mount_path=mount_path)
  return append_schemas(svc, generated, tags_order='append')
# @add_schema_definitions
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
  """Cohere-compatible ``POST /cohere/v1/generate`` handler.

  Parses the body into a ``CohereGenerateRequest``; when ``stream`` is true,
  streams newline-delimited ``StreamingText`` JSON, otherwise aggregates all
  chunks per generation index into one ``Generations`` payload.
  """
  json_str = await req.body()
  try:
    request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
  except orjson.JSONDecodeError as err:
    logger.debug('Sent body: %s', json_str)
    logger.error('Invalid JSON input received: %s', err)
    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
  logger.debug('Received generate request: %s', request)
  err_check = await check_model(request, llm.llm_type)
  if err_check is not None:
    return err_check
  request_id = gen_random_uuid('cohere-generate')
  config = llm.config.with_request(request)
  # Optional client-side templating: format the prompt with prompt_vars.
  if request.prompt_vars is not None:
    prompt = request.prompt.format(**request.prompt_vars)
  else:
    prompt = request.prompt
  # TODO: support end_sequences, stop_sequences, logit_bias, return_likelihoods, truncate
  try:
    result_generator = llm.generate_iterator(prompt, request_id=request_id, stop=request.stop_sequences, **config)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

  def create_stream_response_json(index: int, text: str, is_finished: bool) -> str:
    return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'

  async def generate_stream_generator() -> t.AsyncGenerator[str, None]:
    async for res in result_generator:
      for output in res.outputs:
        # Fix: coerce to bool — the original passed the raw finish_reason
        # (None or a string) where StreamingText declares ``is_finished: bool``.
        yield create_stream_response_json(
          index=output.index, text=output.text, is_finished=output.finish_reason is not None
        )

  try:
    # Streaming case
    if request.stream:
      return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
    # Non-streaming case: accumulate chunks per generation index.
    final_result: GenerationOutput | None = None
    # Fix: ``[[]] * n`` aliases ONE list n times, so every output index appended
    # to the same shared buffer; build n distinct lists instead.
    texts: list[list[str]] = [[] for _ in range(config['num_generations'])]
    token_ids: list[list[int]] = [[] for _ in range(config['num_generations'])]
    async for res in result_generator:
      if await req.is_disconnected():
        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
      for output in res.outputs:
        texts[output.index].append(output.text)
        token_ids[output.index].extend(output.token_ids)
      final_result = res
    if final_result is None:
      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
    final_result = final_result.with_options(
      outputs=[
        output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
        for output in final_result.outputs
      ]
    )
    response = Generations(
      id=request_id,
      generations=[
        Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
        for output in final_result.outputs
      ],
    )
    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str, str]]:
def convert_role(role):
return {'User': 'user', 'Chatbot': 'assistant'}[role]
chat_history = request.chat_history
if chat_history:
messages = [{'role': convert_role(msg['role']), 'content': msg['message']} for msg in chat_history]
else:
messages = []
messages.append({'role': 'user', 'content': request.message})
return messages
# @add_schema_definitions
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
  """Cohere-compatible ``POST /cohere/v1/chat`` handler.

  Structures the body into a ``CohereChatRequest``, renders the conversation
  through the tokenizer's chat template, then either streams Cohere chat
  events (stream-start / text-generation / stream-end) as newline-delimited
  JSON, or returns a single ``Chat`` payload with token counts.
  """
  json_str = await req.body()
  try:
    request = converter.structure(orjson.loads(json_str), CohereChatRequest)
  except orjson.JSONDecodeError as err:
    logger.debug('Sent body: %s', json_str)
    logger.error('Invalid JSON input received: %s', err)
    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
  logger.debug('Received chat completion request: %s', request)
  err_check = await check_model(request, llm.llm_type)
  if err_check is not None:
    return err_check
  request_id = gen_random_uuid('cohere-chat')
  # Render history + new message into a single prompt via the model's chat template.
  prompt: str = llm.tokenizer.apply_chat_template(
    _transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
  )
  logger.debug('Prompt: %r', prompt)
  config = llm.config.with_request(request)
  try:
    result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
  def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
    # One newline-terminated 'text-generation' event per chunk.
    return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'
  async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
    # Emits stream-start, then a text-generation per chunk, then stream-end
    # carrying the fully assembled Chat response with token counts.
    texts: list[str] = []
    token_ids: list[int] = []
    yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'
    it = None
    async for res in result_generator:
      # NOTE(review): only outputs[0] is consumed — assumes a single generation per chat request.
      yield create_stream_generation_json(index=res.outputs[0].index, text=res.outputs[0].text, is_finished=False)
      texts.append(res.outputs[0].text)
      token_ids.extend(res.outputs[0].token_ids)
      it = res
    if it is None:
      raise ValueError('No response from model.')
    num_prompt_tokens = len(t.cast(t.List[int], it.prompt_token_ids))
    num_response_tokens = len(token_ids)
    total_tokens = num_prompt_tokens + num_response_tokens
    json_str = jsonify_attr(
      ChatStreamEnd(
        is_finished=True,
        finish_reason='COMPLETE',
        index=0,
        response=Chat(
          response_id=request_id,
          message=request.message,
          text=''.join(texts),
          prompt=prompt,
          chat_history=request.chat_history,
          token_count={
            'prompt_tokens': num_prompt_tokens,
            'response_tokens': num_response_tokens,
            'total_tokens': total_tokens,
          },
        ),
      )
    )
    yield f'{json_str}\n'
  try:
    if request.stream:
      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
    # Non-streaming case
    final_result: GenerationOutput | None = None
    texts: list[str] = []
    token_ids: list[int] = []
    async for res in result_generator:
      if await req.is_disconnected():
        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
      texts.append(res.outputs[0].text)
      token_ids.extend(res.outputs[0].token_ids)
      final_result = res
    if final_result is None:
      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
    # Collapse the incremental chunks into one final output.
    final_result = final_result.with_options(
      outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
    )
    num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids))
    num_response_tokens = len(token_ids)
    total_tokens = num_prompt_tokens + num_response_tokens
    response = Chat(
      response_id=request_id,
      message=request.message,
      text=''.join(texts),
      prompt=prompt,
      chat_history=request.chat_history,
      token_count={
        'prompt_tokens': num_prompt_tokens,
        'response_tokens': num_response_tokens,
        'total_tokens': total_tokens,
      },
    )
    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def openapi_schema(req: Request) -> Response:
  # GET /cohere/schema: serve the OpenAPI document generated for these routes.
  return schemas.OpenAPIResponse(req)

View File

@@ -48,9 +48,8 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
],
)
mount_path = '/hf'
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, generated_schema, tags_order='append')
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:

View File

@@ -119,12 +119,12 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
],
)
mount_path = '/v1'
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, generated_schema)
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path))
# GET /v1/models
@@ -157,7 +157,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
)
logger.debug('Prompt: %r', prompt)
config = llm.config.with_openai_request(request)
config = llm.config.with_request(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -287,7 +287,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
model_name, request_id = request.model, gen_random_uuid('cmpl')
created_time = int(time.monotonic())
config = llm.config.with_openai_request(request)
config = llm.config.with_request(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -398,3 +398,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def openapi_schema(req: Request) -> Response:
  # GET /v1/schema: serve the OpenAPI document generated for the OpenAI-compatible routes.
  return schemas.OpenAPIResponse(req)

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
import typing as t
from enum import Enum
import attr
from openllm_core.utils import converter
@attr.define
class CohereErrorResponse:
  # Human-readable error message returned to the client.
  text: str

# Unstructure errors as the bare message string rather than {'text': ...}
# (presumably to match Cohere's error body shape — TODO confirm against their API).
converter.register_unstructure_hook(CohereErrorResponse, lambda obj: obj.text)
@attr.define
class CohereGenerateRequest:
  # Request body for POST /cohere/v1/generate; fields mirror Cohere's co.generate
  # parameters (names and defaults assumed from their API — verify against docs).
  prompt: str
  # Optional variables interpolated into ``prompt`` via str.format by the handler.
  prompt_vars: t.Optional[t.Dict[str, t.Any]] = None
  model: t.Optional[str] = None
  preset: t.Optional[str] = None
  # Number of parallel generations to return.
  num_generations: t.Optional[int] = None
  max_tokens: t.Optional[int] = None
  temperature: t.Optional[float] = None
  # top-k / top-p sampling parameters (Cohere naming).
  k: t.Optional[int] = None
  p: t.Optional[float] = None
  frequency_penalty: t.Optional[float] = None
  presence_penalty: t.Optional[float] = None
  # NOTE: end_sequences, return_likelihoods, truncate and logit_bias are accepted
  # but not yet honored by the generate handler (see its TODO).
  end_sequences: t.Optional[t.List[str]] = None
  stop_sequences: t.Optional[t.List[str]] = None
  return_likelihoods: t.Optional[t.Literal['GENERATION', 'ALL', 'NONE']] = None
  truncate: t.Optional[str] = None
  logit_bias: t.Optional[t.Dict[int, float]] = None
  # When true the handler streams newline-delimited StreamingText JSON.
  stream: bool = False
@attr.define
class TokenLikelihood: # pretty sure this is similar to token_logprobs
  token: str
  likelihood: float

@attr.define
class Generation:
  # One completed generation in a /generate response.
  id: str
  text: str
  prompt: str
  likelihood: t.Optional[float] = None
  token_likelihoods: t.List[TokenLikelihood] = attr.field(factory=list)
  finish_reason: t.Optional[str] = None

@attr.define
class Generations:
  # Top-level non-streaming /generate response: one entry per num_generations.
  id: str
  generations: t.List[Generation]
  meta: t.Optional[t.Dict[str, t.Any]] = None

@attr.define
class StreamingText:
  # One newline-delimited streaming chunk emitted by /generate when stream=true.
  index: int
  text: str
  is_finished: bool

@attr.define
class StreamingGenerations:
  # Aggregate streaming payload; not currently emitted by the handlers — schema only.
  id: str
  generations: Generations
  texts: t.List[str]
  meta: t.Optional[t.Dict[str, t.Any]] = None
@attr.define
class CohereChatRequest:
  # Request body for POST /cohere/v1/chat, mirroring Cohere's co.chat parameters.
  # The current user message; always appended as the final 'user' turn.
  message: str
  conversation_id: t.Optional[str] = ''
  model: t.Optional[str] = None
  return_chat_history: t.Optional[bool] = False
  return_prompt: t.Optional[bool] = False
  return_preamble: t.Optional[bool] = False
  # Prior turns as dicts with 'role' ('User'/'Chatbot') and 'message' keys,
  # as consumed by _transpile_cohere_chat_messages.
  chat_history: t.Optional[t.List[t.Dict[str, str]]] = None
  preamble_override: t.Optional[str] = None
  user_name: t.Optional[str] = None
  temperature: t.Optional[float] = 0.8
  max_tokens: t.Optional[int] = None
  stream: t.Optional[bool] = False
  p: t.Optional[float] = None
  # NOTE(review): typed float here but int in CohereGenerateRequest — likely should be int.
  k: t.Optional[float] = None
  logit_bias: t.Optional[t.Dict[int, float]] = None
  # The RAG-related fields below are accepted for schema compatibility but not
  # used by the chat handler.
  search_queries_only: t.Optional[bool] = None
  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  citation_quality: t.Optional[str] = None
  prompt_truncation: t.Optional[str] = None
  connectors: t.Optional[t.List[t.Dict[str, t.Any]]] = None
class StreamEvent(str, Enum):
  # Event-type discriminator for chat streaming responses (str-valued so it
  # serializes directly to the wire format).
  STREAM_START = 'stream-start'
  TEXT_GENERATION = 'text-generation'
  STREAM_END = 'stream-end'
  # TODO: The following are yet to be implemented
  SEARCH_QUERIES_GENERATION = 'search-queries-generation'
  SEARCH_RESULTS = 'search-results'
  CITATION_GENERATION = 'citation-generation'
@attr.define
class Chat:
  # Non-streaming /chat response body; also embedded in ChatStreamEnd.
  response_id: str
  # Echo of the user's request message.
  message: str
  # The full generated reply text.
  text: str
  generation_id: t.Optional[str] = None
  conversation_id: t.Optional[str] = None
  meta: t.Optional[t.Dict[str, t.Any]] = None
  prompt: t.Optional[str] = None
  chat_history: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  preamble: t.Optional[str] = None
  # Filled by the handler with prompt_tokens / response_tokens / total_tokens.
  token_count: t.Optional[t.Dict[str, int]] = None
  # Search/citation fields are placeholders; the handler never populates them.
  is_search_required: t.Optional[bool] = None
  citations: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  search_results: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  search_queries: t.Optional[t.List[t.Dict[str, t.Any]]] = None
@attr.define
class ChatStreamResponse:
  # Common base for chat streaming events.
  is_finished: bool
  event_type: StreamEvent
  index: int

@attr.define
class ChatStreamStart(ChatStreamResponse):
  # First event of a chat stream.
  # NOTE(review): redefining ``event_type`` with a default relies on attrs moving
  # overridden attributes to the end of the field order — confirm callers use
  # keyword arguments (the handlers do).
  generation_id: str
  conversation_id: t.Optional[str] = None
  event_type: StreamEvent = StreamEvent.STREAM_START

@attr.define
class ChatStreamTextGeneration(ChatStreamResponse):
  # One incremental text chunk.
  text: str
  event_type: StreamEvent = StreamEvent.TEXT_GENERATION

@attr.define
class ChatStreamEnd(ChatStreamResponse):
  # Final event; carries the fully assembled Chat response.
  finish_reason: str
  response: Chat
  event_type: StreamEvent = StreamEvent.STREAM_END