mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-05 23:56:47 -05:00
experimental: Cohere compatible endpoints. (#644)
* feat: add generate endpoint
* chore: update generation
* fix(cohere): generate endpoints
* chore: --wip--
* feat: update testing clients and chat implementation
* chore: disable schemas for easter eggs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
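As a quick smoke test (a sketch, not part of this commit: it assumes a server started locally and listening on BentoML's default port 3000), the new Cohere-compatible surface can be exercised with plain HTTP. The request fields follow CohereGenerateRequest defined below in openllm-python/src/openllm/protocol/cohere.py:

# Hypothetical client-side check of the endpoints added in this commit.
import requests

resp = requests.post(
  'http://localhost:3000/cohere/v1/generate',  # assumption: default local server address
  json={'prompt': 'This is a test', 'max_tokens': 256, 'temperature': 0.7, 'num_generations': 2},
  timeout=60,
)
resp.raise_for_status()
print(resp.json())  # an unstructured Generations payload: {'id': ..., 'generations': [...]}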
@@ -8,21 +8,28 @@ Each module should implement the following API:
 """
 from __future__ import annotations
+import importlib
 import typing as t

 from openllm_core.utils import LazyModule

 from . import hf as hf, openai as openai

 if t.TYPE_CHECKING:
   import bentoml
   import openllm

-_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}
+class IntegrationModule(t.Protocol):
+  def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ...
+
+_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []}

 def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
-  return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)
+  for module_name in _import_structure:
+    module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__))
+    svc = module.mount_to_svc(svc, llm)
+  return svc

 __lazy = LazyModule(
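mount_entrypoints now discovers integrations by name instead of hard-wiring the call chain, so adding a backend is a matter of dropping a module into this package and listing it in _import_structure. A minimal sketch of what the IntegrationModule protocol demands (hypothetical module, not part of this commit):

# entrypoints/example.py (hypothetical): the only contract is a module-level mount_to_svc.
from __future__ import annotations
import typing as t

import bentoml
import openllm

def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
  # Mount an ASGI sub-app, register routes, etc., then hand the service back.
  return svc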
@@ -396,7 +396,7 @@ produces:
 summary: Describes a model offering that can be used with the API.
 tags:
   - HF
-x-bentoml-name: adapters_map
+x-bentoml-name: hf_adapters
 responses:
   200:
     description: Return list of LoRA adapters.
@@ -416,6 +416,65 @@ responses:
         $ref: '#/components/schemas/HFErrorResponse'
       description: Not Found
 """
+COHERE_GENERATE_SCHEMA = """\
+---
+consumes:
+  - application/json
+description: >-
+  Given a prompt, the model will return one or more predicted completions, and
+  can also return the probabilities of alternative tokens at each position.
+operationId: cohere__generate
+produces:
+  - application/json
+tags:
+  - Cohere
+x-bentoml-name: cohere_generate
+summary: Creates a completion for the provided prompt and parameters.
+requestBody:
+  required: true
+  content:
+    application/json:
+      schema:
+        $ref: '#/components/schemas/CohereGenerateRequest'
+      examples:
+        one-shot:
+          summary: One-shot input example
+          value:
+            prompt: This is a test
+            max_tokens: 256
+            temperature: 0.7
+            p: 0.43
+            k: 12
+            num_generations: 2
+            stream: false
+        streaming:
+          summary: Streaming input example
+          value:
+            prompt: This is a test
+            max_tokens: 256
+            temperature: 0.7
+            p: 0.43
+            k: 12
+            num_generations: 2
+            stream: true
+            stop_sequences:
+              - "\\n"
+              - "<|endoftext|>"
+"""
+COHERE_CHAT_SCHEMA = """\
+---
+consumes:
+  - application/json
+description: >-
+  Given a list of messages comprising a conversation, the model will return a response.
+operationId: cohere__chat
+produces:
+  - application/json
+tags:
+  - Cohere
+x-bentoml-name: cohere_chat
+summary: Creates a model response for the given chat conversation.
+"""

 _SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
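The _SCHEMAS comprehension keys each docstring by its constant's name with the '_SCHEMA' suffix stripped and the rest lowercased, so the two new schemas register as 'cohere_generate' and 'cohere_chat'. Spelled out (illustration only):

# '_SCHEMA' is 7 characters, hence the k[:-7] slice above.
assert 'COHERE_GENERATE_SCHEMA'[:-7].lower() == 'cohere_generate'
assert 'COHERE_CHAT_SCHEMA'[:-7].lower() == 'cohere_chat'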
@@ -485,12 +544,15 @@ class OpenLLMSchemaGenerator(SchemaGenerator):

 def get_generator(
-  title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
+  title: str,
+  components: list[type[AttrsInstance]] | None = None,
+  tags: list[dict[str, t.Any]] | None = None,
+  inject: bool = True,
 ) -> OpenLLMSchemaGenerator:
   base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
-  if components:
+  if components and inject:
     base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
-  if tags is not None and tags:
+  if tags is not None and tags and inject:
     base_schema['tags'] = tags
   return OpenLLMSchemaGenerator(base_schema)
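The new inject flag lets a caller build a generator without pushing its component schemas and tags into the shared OpenAPI document; cohere.py below passes inject=False, matching the "disable schemas for easter eggs" note in the commit message. A sketch of the two modes (assuming the module import path from the diff):

from openllm.entrypoints._openapi import get_generator

# Default: components and tags would be injected into the base document.
documented = get_generator('openai')
# inject=False: get_schema() still works for mounted routes, but the base
# document stays bare, keeping these endpoints out of the public schema.
undocumented = get_generator('cohere', inject=False)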
openllm-python/src/openllm/entrypoints/cohere.py (new file, 317 lines)
@@ -0,0 +1,317 @@
from __future__ import annotations
import functools
import json
import logging
import traceback
import typing as t
from http import HTTPStatus

import orjson
from starlette.applications import Starlette
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route

from openllm_core.utils import converter, gen_random_uuid

from ._openapi import append_schemas, get_generator

# from ._openapi import add_schema_definitions
from ..protocol.cohere import (
  Chat,
  ChatStreamEnd,
  ChatStreamStart,
  ChatStreamTextGeneration,
  CohereChatRequest,
  CohereErrorResponse,
  CohereGenerateRequest,
  Generation,
  Generations,
  StreamingGenerations,
  StreamingText,
)

schemas = get_generator(
  'cohere',
  components=[
    CohereChatRequest,
    CohereErrorResponse,
    CohereGenerateRequest,
    Generation,
    Generations,
    StreamingGenerations,
    StreamingText,
    Chat,
    ChatStreamStart,
    ChatStreamEnd,
    ChatStreamTextGeneration,
  ],
  tags=[
    {
      'name': 'Cohere',
      'description': 'Cohere compatible API. Currently supports /generate, /chat',
      'externalDocs': 'https://docs.cohere.com/docs/the-cohere-platform',
    }
  ],
  inject=False,
)
logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
  from attr import AttrsInstance
  from starlette.requests import Request
  from starlette.responses import Response

  import bentoml
  import openllm
  from openllm_core._schemas import GenerationOutput
  from openllm_core._typing_compat import M, T


def jsonify_attr(obj: AttrsInstance) -> str:
  return json.dumps(converter.unstructure(obj))


def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
  return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)


async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None:
  if request.model is None or request.model == model:
    return None
  return error_response(
    HTTPStatus.NOT_FOUND,
    f"Model '{request.model}' does not exist. Try 'GET /v1/models' to see current running models.",
  )


def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
  app = Starlette(
    debug=True,
    routes=[
      Route(
        '/v1/generate',
        endpoint=functools.partial(cohere_generate, llm=llm),
        name='cohere_generate',
        methods=['POST'],
        include_in_schema=False,
      ),
      Route(
        '/v1/chat',
        endpoint=functools.partial(cohere_chat, llm=llm),
        name='cohere_chat',
        methods=['POST'],
        include_in_schema=False,
      ),
      Route('/schema', endpoint=openapi_schema, include_in_schema=False),
    ],
  )
  mount_path = '/cohere'

  svc.mount_asgi_app(app, path=mount_path)
  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
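Because the routes are registered with include_in_schema=False, they stay out of the service's main OpenAPI page, while the sub-app still self-describes under its mount path via the /schema route above. A quick probe (a sketch, assuming the same local server on port 3000):

import requests

# /cohere/schema is served by openapi_schema() below; the body is the OpenAPI
# document for the mounted routes (exact serialization depends on OpenLLMSchemaGenerator).
print(requests.get('http://localhost:3000/cohere/schema').text[:200])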

# @add_schema_definitions
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
  json_str = await req.body()
  try:
    request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
  except orjson.JSONDecodeError as err:
    logger.debug('Sent body: %s', json_str)
    logger.error('Invalid JSON input received: %s', err)
    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (check server log).')
  logger.debug('Received generate request: %s', request)

  err_check = await check_model(request, llm.llm_type)
  if err_check is not None:
    return err_check

  request_id = gen_random_uuid('cohere-generate')
  config = llm.config.with_request(request)

  if request.prompt_vars is not None:
    prompt = request.prompt.format(**request.prompt_vars)
  else:
    prompt = request.prompt

  # TODO: support end_sequences, stop_sequences, logit_bias, return_likelihoods, truncate

  try:
    result_generator = llm.generate_iterator(prompt, request_id=request_id, stop=request.stop_sequences, **config)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

  def create_stream_response_json(index: int, text: str, is_finished: bool) -> str:
    return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'

  async def generate_stream_generator() -> t.AsyncGenerator[str, None]:
    async for res in result_generator:
      for output in res.outputs:
        # is_finished is a bool; finish_reason stays None until generation completes.
        yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason is not None)

  try:
    # streaming case
    if request.stream:
      return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
    # non-streaming case
    final_result: GenerationOutput | None = None
    # Use comprehensions here: `[[]] * n` would alias one list across every index.
    texts: list[list[str]] = [[] for _ in range(config['num_generations'])]
    token_ids: list[list[int]] = [[] for _ in range(config['num_generations'])]
    async for res in result_generator:
      if await req.is_disconnected():
        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
      for output in res.outputs:
        texts[output.index].append(output.text)
        token_ids[output.index].extend(output.token_ids)
      final_result = res
    if final_result is None:
      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
    final_result = final_result.with_options(
      outputs=[
        output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
        for output in final_result.outputs
      ]
    )
    response = Generations(
      id=request_id,
      generations=[
        Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
        for output in final_result.outputs
      ],
    )
    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
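When stream is true, the endpoint replies with newline-delimited JSON, one StreamingText object per line (see create_stream_response_json above). A client-side sketch, assuming the same hypothetical local server as earlier:

import json
import requests

with requests.post(
  'http://localhost:3000/cohere/v1/generate',
  json={'prompt': 'This is a test', 'max_tokens': 64, 'stream': True},
  stream=True,
) as resp:
  for line in resp.iter_lines():
    if line:
      chunk = json.loads(line)  # {'index': ..., 'text': ..., 'is_finished': ...}
      print(chunk['text'], end='', flush=True)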

def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str, str]]:
  def convert_role(role):
    return {'User': 'user', 'Chatbot': 'assistant'}[role]

  chat_history = request.chat_history
  if chat_history:
    messages = [{'role': convert_role(msg['role']), 'content': msg['message']} for msg in chat_history]
  else:
    messages = []
  messages.append({'role': 'user', 'content': request.message})
  return messages
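The role mapping is the only translation needed between Cohere's chat_history shape and the roles that the tokenizer's chat template expects. For instance (illustration mirroring the function above):

from openllm.protocol.cohere import CohereChatRequest
from openllm.entrypoints.cohere import _transpile_cohere_chat_messages

req = CohereChatRequest(
  message='What is BentoML?',
  chat_history=[
    {'role': 'User', 'message': 'Hi there'},
    {'role': 'Chatbot', 'message': 'Hello! How can I help?'},
  ],
)
assert _transpile_cohere_chat_messages(req) == [
  {'role': 'user', 'content': 'Hi there'},
  {'role': 'assistant', 'content': 'Hello! How can I help?'},
  {'role': 'user', 'content': 'What is BentoML?'},
]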

# @add_schema_definitions
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
  json_str = await req.body()
  try:
    request = converter.structure(orjson.loads(json_str), CohereChatRequest)
  except orjson.JSONDecodeError as err:
    logger.debug('Sent body: %s', json_str)
    logger.error('Invalid JSON input received: %s', err)
    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (check server log).')
  logger.debug('Received chat completion request: %s', request)

  err_check = await check_model(request, llm.llm_type)
  if err_check is not None:
    return err_check

  request_id = gen_random_uuid('cohere-chat')
  prompt: str = llm.tokenizer.apply_chat_template(
    _transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
  )
  logger.debug('Prompt: %r', prompt)
  config = llm.config.with_request(request)

  try:
    result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')

  def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
    return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'

  async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
    texts: list[str] = []
    token_ids: list[int] = []
    yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'

    it = None
    async for res in result_generator:
      yield create_stream_generation_json(index=res.outputs[0].index, text=res.outputs[0].text, is_finished=False)
      texts.append(res.outputs[0].text)
      token_ids.extend(res.outputs[0].token_ids)
      it = res

    if it is None:
      raise ValueError('No response from model.')
    num_prompt_tokens = len(t.cast(t.List[int], it.prompt_token_ids))
    num_response_tokens = len(token_ids)
    total_tokens = num_prompt_tokens + num_response_tokens

    json_str = jsonify_attr(
      ChatStreamEnd(
        is_finished=True,
        finish_reason='COMPLETE',
        index=0,
        response=Chat(
          response_id=request_id,
          message=request.message,
          text=''.join(texts),
          prompt=prompt,
          chat_history=request.chat_history,
          token_count={
            'prompt_tokens': num_prompt_tokens,
            'response_tokens': num_response_tokens,
            'total_tokens': total_tokens,
          },
        ),
      )
    )
    yield f'{json_str}\n'

  try:
    if request.stream:
      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
    # non-streaming case
    final_result: GenerationOutput | None = None
    texts: list[str] = []
    token_ids: list[int] = []
    async for res in result_generator:
      if await req.is_disconnected():
        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
      texts.append(res.outputs[0].text)
      token_ids.extend(res.outputs[0].token_ids)
      final_result = res
    if final_result is None:
      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
    final_result = final_result.with_options(
      outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
    )
    num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids))
    num_response_tokens = len(token_ids)
    total_tokens = num_prompt_tokens + num_response_tokens

    response = Chat(
      response_id=request_id,
      message=request.message,
      text=''.join(texts),
      prompt=prompt,
      chat_history=request.chat_history,
      token_count={
        'prompt_tokens': num_prompt_tokens,
        'response_tokens': num_response_tokens,
        'total_tokens': total_tokens,
      },
    )
    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
  except Exception as err:
    traceback.print_exc()
    logger.error('Error generating completion: %s', err)
    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')


def openapi_schema(req: Request) -> Response:
  return schemas.OpenAPIResponse(req)
@@ -48,9 +48,8 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
     ],
   )
   mount_path = '/hf'
-  generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
   svc.mount_asgi_app(app, path=mount_path)
-  return append_schemas(svc, generated_schema, tags_order='append')
+  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')


 def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
@@ -119,12 +119,12 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
       Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
       Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
       Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
+      Route('/schema', endpoint=openapi_schema, include_in_schema=False),
     ],
   )
   mount_path = '/v1'
-  generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
   svc.mount_asgi_app(app, path=mount_path)
-  return append_schemas(svc, generated_schema)
+  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path))


 # GET /v1/models
@@ -157,7 +157,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
     request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
   )
   logger.debug('Prompt: %r', prompt)
-  config = llm.config.with_openai_request(request)
+  config = llm.config.with_request(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -287,7 +287,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:

   model_name, request_id = request.model, gen_random_uuid('cmpl')
   created_time = int(time.monotonic())
-  config = llm.config.with_openai_request(request)
+  config = llm.config.with_request(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -398,3 +398,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
     traceback.print_exc()
     logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+
+def openapi_schema(req: Request) -> Response:
+  return schemas.OpenAPIResponse(req)
openllm-python/src/openllm/protocol/cohere.py (new file, 154 lines)
@@ -0,0 +1,154 @@
from __future__ import annotations
import typing as t
from enum import Enum

import attr

from openllm_core.utils import converter


@attr.define
class CohereErrorResponse:
  text: str


converter.register_unstructure_hook(CohereErrorResponse, lambda obj: obj.text)


@attr.define
class CohereGenerateRequest:
  prompt: str
  prompt_vars: t.Optional[t.Dict[str, t.Any]] = None
  model: t.Optional[str] = None
  preset: t.Optional[str] = None
  num_generations: t.Optional[int] = None
  max_tokens: t.Optional[int] = None
  temperature: t.Optional[float] = None
  k: t.Optional[int] = None
  p: t.Optional[float] = None
  frequency_penalty: t.Optional[float] = None
  presence_penalty: t.Optional[float] = None
  end_sequences: t.Optional[t.List[str]] = None
  stop_sequences: t.Optional[t.List[str]] = None
  return_likelihoods: t.Optional[t.Literal['GENERATION', 'ALL', 'NONE']] = None
  truncate: t.Optional[str] = None
  logit_bias: t.Optional[t.Dict[int, float]] = None
  stream: bool = False


@attr.define
class TokenLikelihood:  # pretty sure this is similar to token_logprobs
  token: str
  likelihood: float


@attr.define
class Generation:
  id: str
  text: str
  prompt: str
  likelihood: t.Optional[float] = None
  token_likelihoods: t.List[TokenLikelihood] = attr.field(factory=list)
  finish_reason: t.Optional[str] = None


@attr.define
class Generations:
  id: str
  generations: t.List[Generation]
  meta: t.Optional[t.Dict[str, t.Any]] = None


@attr.define
class StreamingText:
  index: int
  text: str
  is_finished: bool


@attr.define
class StreamingGenerations:
  id: str
  generations: Generations
  texts: t.List[str]
  meta: t.Optional[t.Dict[str, t.Any]] = None


@attr.define
class CohereChatRequest:
  message: str
  conversation_id: t.Optional[str] = ''
  model: t.Optional[str] = None
  return_chat_history: t.Optional[bool] = False
  return_prompt: t.Optional[bool] = False
  return_preamble: t.Optional[bool] = False
  chat_history: t.Optional[t.List[t.Dict[str, str]]] = None
  preamble_override: t.Optional[str] = None
  user_name: t.Optional[str] = None
  temperature: t.Optional[float] = 0.8
  max_tokens: t.Optional[int] = None
  stream: t.Optional[bool] = False
  p: t.Optional[float] = None
  k: t.Optional[float] = None
  logit_bias: t.Optional[t.Dict[int, float]] = None
  search_queries_only: t.Optional[bool] = None
  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  citation_quality: t.Optional[str] = None
  prompt_truncation: t.Optional[str] = None
  connectors: t.Optional[t.List[t.Dict[str, t.Any]]] = None


class StreamEvent(str, Enum):
  STREAM_START = 'stream-start'
  TEXT_GENERATION = 'text-generation'
  STREAM_END = 'stream-end'
  # TODO: The following are yet to be implemented
  SEARCH_QUERIES_GENERATION = 'search-queries-generation'
  SEARCH_RESULTS = 'search-results'
  CITATION_GENERATION = 'citation-generation'


@attr.define
class Chat:
  response_id: str
  message: str
  text: str
  generation_id: t.Optional[str] = None
  conversation_id: t.Optional[str] = None
  meta: t.Optional[t.Dict[str, t.Any]] = None
  prompt: t.Optional[str] = None
  chat_history: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  preamble: t.Optional[str] = None
  token_count: t.Optional[t.Dict[str, int]] = None
  is_search_required: t.Optional[bool] = None
  citations: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  search_results: t.Optional[t.List[t.Dict[str, t.Any]]] = None
  search_queries: t.Optional[t.List[t.Dict[str, t.Any]]] = None


@attr.define
class ChatStreamResponse:
  is_finished: bool
  event_type: StreamEvent
  index: int


@attr.define
class ChatStreamStart(ChatStreamResponse):
  generation_id: str
  conversation_id: t.Optional[str] = None
  event_type: StreamEvent = StreamEvent.STREAM_START


@attr.define
class ChatStreamTextGeneration(ChatStreamResponse):
  text: str
  event_type: StreamEvent = StreamEvent.TEXT_GENERATION


@attr.define
class ChatStreamEnd(ChatStreamResponse):
  finish_reason: str
  response: Chat
  event_type: StreamEvent = StreamEvent.STREAM_END
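These attrs models are wired into openllm_core's cattrs converter, which is exactly how the endpoints above parse request bodies. A round-trip sketch:

import orjson
from openllm_core.utils import converter
from openllm.protocol.cohere import CohereGenerateRequest

payload = b'{"prompt": "This is a test", "max_tokens": 256, "k": 12, "p": 0.43, "num_generations": 2}'
req = converter.structure(orjson.loads(payload), CohereGenerateRequest)
assert req.stream is False  # omitted fields take their declared defaults
assert converter.unstructure(req)['prompt'] == 'This is a test'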