feat: OpenAI-compatible API (#417)

Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
XunchaoZ
2023-10-07 00:50:03 -04:00
committed by GitHub
parent b43fabfff8
commit 04bb29a264
11 changed files with 296 additions and 8 deletions
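
The new /v1/completions and /v1/chat/completions routes added below follow the OpenAI request/response shapes defined in the new openllm.protocol.openai module. As a rough illustration only (not part of this commit), a client could call the chat endpoint like this, assuming an OpenLLM server is already running locally (host, port, and field values are placeholders):

import requests

# Hypothetical local server address; point this at wherever `openllm start` is serving.
BASE = 'http://localhost:3000'

resp = requests.post(
    f'{BASE}/v1/chat/completions',
    json={
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello!'},
        ],
        'temperature': 0.7,
        'stream': False,
    },
)
resp.raise_for_status()
# The assistant reply lives in choices[0].message.content, mirroring ChatCompletionResponse.
print(resp.json()['choices'][0]['message']['content'])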

View File

@@ -40,6 +40,7 @@ _import_structure: dict[str, list[str]] = {
"playground": [],
"testing": [],
"prompts": ["PromptTemplate"],
"protocol": ["openai"],
"utils": ["infer_auto_class"],
"serialisation": ["ggml", "transformers"],
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
@@ -72,6 +73,7 @@ if _t.TYPE_CHECKING:
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
from .serialisation import ggml as ggml, transformers as transformers
from .prompts import PromptTemplate as PromptTemplate
from .protocol import openai as openai
from .utils import infer_auto_class as infer_auto_class
try:

View File

@@ -1241,6 +1241,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
echo = attrs.pop('echo', False)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
temperature = attrs.pop('temperature', self.config['temperature'])
top_p = attrs.pop('top_p', self.config['top_p'])
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
@@ -1254,8 +1256,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
for tid in stop_token_ids:
if tid: stop_.add(self.tokenizer.decode(tid))
if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
if temperature <= 1e-5: top_p = 1.0
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
sampling_params = config.to_sampling_config()
@@ -1276,6 +1277,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
echo = attrs.pop('echo', False)
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
temperature = attrs.pop('temperature', self.config['temperature'])
top_p = attrs.pop('top_p', self.config['top_p'])
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
@@ -1289,8 +1292,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
for tid in stop_token_ids:
if tid: stop_.add(self.tokenizer.decode(tid))
if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
if temperature <= 1e-5: top_p = 1.0
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
sampling_params = config.to_sampling_config()
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
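
The two vLLM runner hunks above change how sampling parameters are resolved: temperature and top_p are now popped from the per-request attrs (falling back to the model config defaults) instead of always being read from the static config, and top_p is pinned to 1.0 whenever the effective temperature is near zero, since nucleus sampling has no effect under greedy decoding. The same resolution logic in isolation, as a sketch (the helper name is illustrative, not part of the commit):

def resolve_top_p(temperature: float, top_p: float) -> float:
    # Near-greedy decoding ignores nucleus sampling, so force top_p to 1.0
    # to keep the sampling config unambiguous.
    return 1.0 if temperature <= 1e-5 else top_p

assert resolve_top_p(0.0, 0.9) == 1.0
assert resolve_top_p(0.8, 0.9) == 0.9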

View File

@@ -69,6 +69,107 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
else:
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
@svc.api(route='/v1/completions',
input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
output=bentoml.io.Text())
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
prompt = input_dict.pop('prompt', None)
if prompt is None: raise ValueError("'prompt' should not be None.")
stream = input_dict.pop('stream', False)
config = {
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
'temperature': input_dict.pop('temperature', llm_config['temperature']),
'top_p': input_dict.pop('top_p', llm_config['top_p']),
'n': input_dict.pop('n', llm_config['n']),
'logprobs': input_dict.pop('logprobs', llm_config['logprobs']),
'echo': input_dict.pop('echo', False),
'stop': input_dict.pop('stop', llm_config['stop']),
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
'best_of': input_dict.pop('best_of', llm_config['best_of']),
}
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
async for response in responses:
st = openllm.openai.CompletionResponseStream(choices=[openllm.openai.CompletionTextChoice(text=response, index=0)], model=runner.llm_type) # TODO: logprobs, finish_reason
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
yield 'data: [DONE]\n\n'
if stream:
ctx.response.headers['Content-Type'] = 'text/event-stream'
if runner.backend == 'vllm':
responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
else:
responses = runner.generate_iterator.async_stream(prompt, **config)
return stream_response_generator(responses)
else:
ctx.response.headers['Content-Type'] = 'application/json'
if runner.backend == 'vllm':
async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
responses = output
if responses is None: raise ValueError("'responses' should not be None.")
else:
responses = await runner.generate.async_run(prompt, **config)
return orjson.dumps(
openllm.utils.bentoml_cattr.unstructure(
openllm.openai.CompletionResponse(choices=[openllm.openai.CompletionTextChoice(text=response, index=i) for i, response in enumerate(responses)],
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
)).decode()
@svc.api(route='/v1/chat/completions',
input=bentoml.io.JSON.from_sample(
openllm.utils.bentoml_cattr.unstructure(
openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}],
model=runner.llm_type))),
output=bentoml.io.Text())
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
messages = input_dict.pop('messages', None)
if messages is None: raise ValueError("'messages' should not be None.")
prompt = openllm.openai.messages_to_prompt(messages)
stream = input_dict.pop('stream', False)
config = {
'temperature': input_dict.pop('temperature', llm_config['temperature']),
'top_p': input_dict.pop('top_p', llm_config['top_p']),
'n': input_dict.pop('n', llm_config['n']),
'echo': input_dict.pop('echo', False),
'stop': input_dict.pop('stop', llm_config['stop']),
'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
}
async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
async for response in responses:
st = openllm.openai.ChatCompletionResponseStream(
choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=response), finish_reason=None)], model=runner.llm_type)
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
final = openllm.openai.ChatCompletionResponseStream(
choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=''), finish_reason='stop')], model=runner.llm_type)
yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(final)).decode()}\n\n'
yield 'data: [DONE]\n\n'
if stream:
ctx.response.headers['Content-Type'] = 'text/event-stream'
if runner.backend == 'vllm':
responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
else:
responses = runner.generate_iterator.async_stream(prompt, **config)
return stream_response_generator(responses)
else:
ctx.response.headers['Content-Type'] = 'application/json'
if runner.backend == 'vllm':
async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
responses = output
if responses is None: raise ValueError("'responses' should not be None.")
else:
responses = await runner.generate.async_run(prompt, **config)
return orjson.dumps(
openllm.utils.bentoml_cattr.unstructure(
openllm.openai.ChatCompletionResponse(choices=[
openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
],
model=runner.llm_type) # TODO: logprobs, finish_reason and usage
)).decode('utf-8')
@svc.api(route='/v1/metadata',
input=bentoml.io.Text(),
output=bentoml.io.JSON.from_sample({
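
When stream=True, both endpoints above return a text/event-stream body whose events are 'data: <json>' lines separated by blank lines and terminated by 'data: [DONE]'. A rough client-side sketch for consuming the completion stream (requests-based; the URL is a local placeholder and not part of this commit):

import json
import requests

with requests.post('http://localhost:3000/v1/completions',
                   json={'prompt': 'What is 1+1?', 'stream': True},
                   stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue
        payload = line[len('data: '):]
        if payload == '[DONE]':
            break
        chunk = json.loads(payload)
        # Each chunk mirrors CompletionResponseStream: choices[].text holds the next piece of text.
        print(chunk['choices'][0]['text'], end='', flush=True)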

View File

@@ -4,7 +4,7 @@ import typing as t
import bentoml
import openllm
from openllm._prompt import process_prompt
from openllm_core.prompts import process_prompt
from openllm.utils import generate_labels
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
import openllm
from openllm_core._prompt import process_prompt
from openllm_core.prompts import process_prompt
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers

View File

@@ -0,0 +1,19 @@
'''Protocol-related packages for all library integrations.
Currently supports the OpenAI-compatible API.
'''
from __future__ import annotations
import os
import typing as t
from openllm_core.utils import LazyModule
_import_structure: dict[str, list[str]] = {'openai': []}
if t.TYPE_CHECKING:
from . import openai as openai
__lazy = LazyModule(__name__, os.path.abspath(__file__), _import_structure)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
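
With 'openai' registered in the lazy import structure here (and re-exported from the top-level package in the first hunk of this commit), the protocol module is only imported on first attribute access. Illustrative usage, with a placeholder model string:

import openllm

# Resolved lazily on first access; also re-exported as openllm.openai.
req = openllm.protocol.openai.CompletionRequest(prompt='What is 1+1?', model='facebook/opt-1.3b')
print(openllm.utils.bentoml_cattr.unstructure(req)['prompt'])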

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
import time
import typing as t
import attr
import openllm_core
@attr.define
class CompletionRequest:
prompt: str
model: str = attr.field(default=None)
suffix: t.Optional[str] = attr.field(default=None)
max_tokens: t.Optional[int] = attr.field(default=16)
temperature: t.Optional[float] = attr.field(default=1.0)
top_p: t.Optional[float] = attr.field(default=1)
n: t.Optional[int] = attr.field(default=1)
stream: t.Optional[bool] = attr.field(default=False)
logprobs: t.Optional[int] = attr.field(default=None)
echo: t.Optional[bool] = attr.field(default=False)
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
presence_penalty: t.Optional[float] = attr.field(default=0.0)
frequency_penalty: t.Optional[float] = attr.field(default=0.0)
best_of: t.Optional[int] = attr.field(default=1)
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
user: t.Optional[str] = attr.field(default=None)
@attr.define
class ChatCompletionRequest:
messages: t.List[t.Dict[str, str]]
model: str = attr.field(default=None)
functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
temperature: t.Optional[float] = attr.field(default=1.0)
top_p: t.Optional[float] = attr.field(default=1)
n: t.Optional[int] = attr.field(default=1)
stream: t.Optional[bool] = attr.field(default=False)
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
max_tokens: t.Optional[int] = attr.field(default=None)
presence_penalty: t.Optional[float] = attr.field(default=0.0)
frequency_penalty: t.Optional[float] = attr.field(default=0.0)
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
user: t.Optional[str] = attr.field(default=None)
@attr.define
class LogProbs:
text_offset: t.List[int] = attr.field(default=attr.Factory(list))
token_logprobs: t.List[float] = attr.field(default=attr.Factory(list))
tokens: t.List[str] = attr.field(default=attr.Factory(list))
top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
@attr.define
class CompletionTextChoice:
text: str
index: int
logprobs: LogProbs = attr.field(default=attr.Factory(lambda: LogProbs()))
finish_reason: str = attr.field(default=None)
@attr.define
class Usage:
prompt_tokens: int = attr.field(default=0)
completion_tokens: int = attr.field(default=0)
total_tokens: int = attr.field(default=0)
@attr.define
class CompletionResponse:
choices: t.List[CompletionTextChoice]
model: str
object: str = 'text_completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
@attr.define
class CompletionResponseStream:
choices: t.List[CompletionTextChoice]
model: str
object: str = 'text_completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
LiteralRole = t.Literal['system', 'user', 'assistant']
class Message(t.TypedDict):
role: LiteralRole
content: str
@attr.define
class Delta:
role: LiteralRole
content: str
@attr.define
class ChatCompletionChoice:
index: int
message: Message
finish_reason: str = attr.field(default=None)
@attr.define
class ChatCompletionStreamChoice:
index: int
delta: Message
finish_reason: str = attr.field(default=None)
@attr.define
class ChatCompletionResponse:
choices: t.List[ChatCompletionChoice]
model: str
object: str = 'chat.completion'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
@attr.define
class ChatCompletionResponseStream:
choices: t.List[ChatCompletionStreamChoice]
model: str
object: str = 'chat.completion.chunk'
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
def messages_to_prompt(messages: list[Message]) -> str:
formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
return f'{formatted}\nassistant:'
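
For reference, messages_to_prompt flattens an OpenAI-style chat into the single prompt string fed to the runner; on the sample messages used in the service definition above it produces:

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
print(messages_to_prompt(messages))
# system: You are a helpful assistant.
# user: Hello!
# assistant: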