Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-01-25 15:57:51 -05:00
feat: OpenAI-compatible API (#417)
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -40,6 +40,7 @@ _import_structure: dict[str, list[str]] = {
  "playground": [],
  "testing": [],
  "prompts": ["PromptTemplate"],
  "protocol": ["openai"],
  "utils": ["infer_auto_class"],
  "serialisation": ["ggml", "transformers"],
  "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
@@ -72,6 +73,7 @@ if _t.TYPE_CHECKING:
  from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
  from .serialisation import ggml as ggml, transformers as transformers
  from .prompts import PromptTemplate as PromptTemplate
  from .protocol import openai as openai
  from .utils import infer_auto_class as infer_auto_class

try:
@@ -1241,6 +1241,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
echo = attrs.pop('echo', False)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
temperature = attrs.pop('temperature', self.config['temperature'])
top_p = attrs.pop('top_p', self.config['top_p'])
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
@@ -1254,8 +1256,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
for tid in stop_token_ids:
  if tid: stop_.add(self.tokenizer.decode(tid))

if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
if temperature <= 1e-5: top_p = 1.0
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
sampling_params = config.to_sampling_config()
@@ -1276,6 +1277,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
echo = attrs.pop('echo', False)
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
temperature = attrs.pop('temperature', self.config['temperature'])
top_p = attrs.pop('top_p', self.config['top_p'])
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
@@ -1289,8 +1292,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
for tid in stop_token_ids:
  if tid: stop_.add(self.tokenizer.decode(tid))

if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
if temperature <= 1e-5: top_p = 1.0
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
sampling_params = config.to_sampling_config()
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
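Aside: the two hunks above switch the greedy-decoding guard from the configured default temperature to the per-request temperature. A standalone sketch of the resulting behavior (the helper name is hypothetical, not part of the commit):

# Hypothetical helper illustrating the sampling guard above: when the requested
# temperature is effectively zero (greedy decoding), top_p is forced to 1.0 so
# nucleus sampling does not filter out the argmax token.
def effective_top_p(temperature: float, top_p: float) -> float:
  return 1.0 if temperature <= 1e-5 else top_p

assert effective_top_p(0.0, 0.3) == 1.0  # greedy request: top_p overridden
assert effective_top_p(0.8, 0.3) == 0.3  # sampled request: top_p respected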
@@ -69,6 +69,107 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
  else:
    return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())

@svc.api(route='v1/completions',
         input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
         output=bentoml.io.Text())
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
  prompt = input_dict.pop('prompt', None)
  if prompt is None: raise ValueError("'prompt' should not be None.")
  stream = input_dict.pop('stream', False)
  config = {
      'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
      'temperature': input_dict.pop('temperature', llm_config['temperature']),
      'top_p': input_dict.pop('top_p', llm_config['top_p']),
      'n': input_dict.pop('n', llm_config['n']),
      'logprobs': input_dict.pop('logprobs', llm_config['logprobs']),
      'echo': input_dict.pop('echo', False),
      'stop': input_dict.pop('stop', llm_config['stop']),
      'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
      'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
      'best_of': input_dict.pop('best_of', llm_config['best_of']),
  }

  async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
    async for response in responses:
      st = openllm.openai.CompletionResponseStream(choices=[openllm.openai.CompletionTextChoice(text=response, index=0)], model=runner.llm_type)  # TODO: logprobs, finish_reason
      yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
    yield 'data: [DONE]\n\n'

  if stream:
    ctx.response.headers['Content-Type'] = 'text/event-stream'
    if runner.backend == 'vllm':
      responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
    else:
      responses = runner.generate_iterator.async_stream(prompt, **config)
    return stream_response_generator(responses)
  else:
    ctx.response.headers['Content-Type'] = 'application/json'
    if runner.backend == 'vllm':
      async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
        responses = output
      if responses is None: raise ValueError("'responses' should not be None.")
    else:
      responses = await runner.generate.async_run(prompt, **config)

    return orjson.dumps(
        openllm.utils.bentoml_cattr.unstructure(
            openllm.openai.CompletionResponse(choices=[openllm.openai.CompletionTextChoice(text=response, index=i) for i, response in enumerate(responses)],
                                              model=runner.llm_type)  # TODO: logprobs, finish_reason and usage
        )).decode()

@svc.api(route='/v1/chat/completions',
         input=bentoml.io.JSON.from_sample(
             openllm.utils.bentoml_cattr.unstructure(
                 openllm.openai.ChatCompletionRequest(messages=[{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'Hello!'}],
                                                      model=runner.llm_type))),
         output=bentoml.io.Text())
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
  prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
  stream = input_dict.pop('stream', False)
  config = {
      'temperature': input_dict.pop('temperature', llm_config['temperature']),
      'top_p': input_dict.pop('top_p', llm_config['top_p']),
      'n': input_dict.pop('n', llm_config['n']),
      'echo': input_dict.pop('echo', False),
      'stop': input_dict.pop('stop', llm_config['stop']),
      'max_new_tokens': input_dict.pop('max_tokens', llm_config['max_new_tokens']),
      'presence_penalty': input_dict.pop('presence_penalty', llm_config['presence_penalty']),
      'frequency_penalty': input_dict.pop('frequency_penalty', llm_config['frequency_penalty']),
  }

  async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t.AsyncGenerator[str, None]:
    async for response in responses:
      st = openllm.openai.ChatCompletionResponseStream(
          choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=response), finish_reason=None)], model=runner.llm_type)
      yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(st)).decode()}\n\n'
    final = openllm.openai.ChatCompletionResponseStream(
        choices=[openllm.openai.ChatCompletionStreamChoice(index=0, delta=openllm.openai.Message(role='assistant', content=''), finish_reason='stop')], model=runner.llm_type)
    yield f'data: {orjson.dumps(openllm.utils.bentoml_cattr.unstructure(final)).decode()}\n\n'
    yield 'data: [DONE]\n\n'

  if stream:
    ctx.response.headers['Content-Type'] = 'text/event-stream'
    if runner.backend == 'vllm':
      responses = runner.vllm_generate_iterator.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config)
    else:
      responses = runner.generate_iterator.async_stream(prompt, **config)
    return stream_response_generator(responses)
  else:
    ctx.response.headers['Content-Type'] = 'application/json'
    if runner.backend == 'vllm':
      async for output in runner.vllm_generate.async_stream(prompt, request_id=openllm_core.utils.gen_random_uuid(), **config):
        responses = output
      if responses is None: raise ValueError("'responses' should not be None.")
    else:
      responses = await runner.generate.async_run(prompt, **config)
    return orjson.dumps(
        openllm.utils.bentoml_cattr.unstructure(
            openllm.openai.ChatCompletionResponse(choices=[
                openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
            ],
                                                  model=model)  # TODO: logprobs, finish_reason and usage
        )).decode('utf-8')

@svc.api(route='/v1/metadata',
         input=bentoml.io.Text(),
         output=bentoml.io.JSON.from_sample({
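For reference, a minimal client-side sketch of exercising the two routes added above once the service is running. The host/port (BentoML's default of localhost:3000) is an assumption about the deployment, and the completions route is registered without a leading slash in the diff; the sketch assumes it is served at /v1/completions:

import json
import requests

BASE = 'http://localhost:3000'  # assumed BentoML serving address

# Non-streaming completion: the handler returns a JSON-encoded CompletionResponse as text.
resp = requests.post(f'{BASE}/v1/completions', json={'prompt': 'What is 1+1?', 'stream': False})
print(json.loads(resp.text)['choices'][0]['text'])

# Streaming chat completion: the handler emits Server-Sent Events ('data: ...' lines).
with requests.post(f'{BASE}/v1/chat/completions',
                   json={'messages': [{'role': 'user', 'content': 'Hello!'}], 'stream': True},
                   stream=True) as r:
  for line in r.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '): continue
    payload = line[len('data: '):]
    if payload == '[DONE]': break
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta']['content'], end='', flush=True)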
@@ -4,7 +4,7 @@ import typing as t

import bentoml
import openllm
from openllm._prompt import process_prompt
from openllm_core.prompts import process_prompt
from openllm.utils import generate_labels
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t

import openllm
from openllm_core._prompt import process_prompt
from openllm_core.prompts import process_prompt
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
openllm-python/src/openllm/protocol/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
'''Protocol-related packages for all library integrations.

Currently support OpenAI compatible API.
'''
from __future__ import annotations
import os
import typing as t

from openllm_core.utils import LazyModule

_import_structure: dict[str, list[str]] = {'openai': []}

if t.TYPE_CHECKING:
  from . import openai as openai

__lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
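As a quick note on the lazy layout above, the openai submodule is only imported on first access through LazyModule; a small usage sketch (the model name is an arbitrary placeholder, not part of the commit):

from openllm.protocol import openai as openai_protocol  # resolved lazily via LazyModule

req = openai_protocol.CompletionRequest(prompt='What is 1+1?', model='facebook/opt-1.3b')  # placeholder model id
print(req.max_tokens, req.temperature)  # 16 1.0 -- the OpenAI-style defaults defined in openai.py below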
openllm-python/src/openllm/protocol/openai.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from __future__ import annotations
import time
import typing as t

import attr

import openllm_core

@attr.define
class CompletionRequest:
  prompt: str
  model: str = attr.field(default=None)
  suffix: t.Optional[str] = attr.field(default=None)
  max_tokens: t.Optional[int] = attr.field(default=16)
  temperature: t.Optional[float] = attr.field(default=1.0)
  top_p: t.Optional[float] = attr.field(default=1)
  n: t.Optional[int] = attr.field(default=1)
  stream: t.Optional[bool] = attr.field(default=False)
  logprobs: t.Optional[int] = attr.field(default=None)
  echo: t.Optional[bool] = attr.field(default=False)
  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
  presence_penalty: t.Optional[float] = attr.field(default=0.0)
  frequency_penalty: t.Optional[float] = attr.field(default=0.0)
  best_of: t.Optional[int] = attr.field(default=1)
  logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
  user: t.Optional[str] = attr.field(default=None)

@attr.define
class ChatCompletionRequest:
  messages: t.List[t.Dict[str, str]]
  model: str = attr.field(default=None)
  functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
  function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
  temperature: t.Optional[float] = attr.field(default=1.0)
  top_p: t.Optional[float] = attr.field(default=1)
  n: t.Optional[int] = attr.field(default=1)
  stream: t.Optional[bool] = attr.field(default=False)
  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
  max_tokens: t.Optional[int] = attr.field(default=None)
  presence_penalty: t.Optional[float] = attr.field(default=0.0)
  frequency_penalty: t.Optional[float] = attr.field(default=0.0)
  logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
  user: t.Optional[str] = attr.field(default=None)

@attr.define
class LogProbs:
  text_offset: t.List[int] = attr.field(default=attr.Factory(list))
  token_logprobs: t.List[float] = attr.field(default=attr.Factory(list))
  tokens: t.List[str] = attr.field(default=attr.Factory(list))
  top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))

@attr.define
class CompletionTextChoice:
  text: str
  index: int
  logprobs: LogProbs = attr.field(default=attr.Factory(lambda: LogProbs()))
  finish_reason: str = attr.field(default=None)

@attr.define
class Usage:
  prompt_tokens: int = attr.field(default=0)
  completion_tokens: int = attr.field(default=0)
  total_tokens: int = attr.field(default=0)

@attr.define
class CompletionResponse:
  choices: t.List[CompletionTextChoice]
  model: str
  object: str = 'text_completion'
  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
  created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))

@attr.define
class CompletionResponseStream:
  choices: t.List[CompletionTextChoice]
  model: str
  object: str = 'text_completion'
  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
  created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))

LiteralRole = t.Literal['system', 'user', 'assistant']

class Message(t.TypedDict):
  role: LiteralRole
  content: str

@attr.define
class Delta:
  role: LiteralRole
  content: str

@attr.define
class ChatCompletionChoice:
  index: int
  message: Message
  finish_reason: str = attr.field(default=None)

@attr.define
class ChatCompletionStreamChoice:
  index: int
  delta: Message
  finish_reason: str = attr.field(default=None)

@attr.define
class ChatCompletionResponse:
  choices: t.List[ChatCompletionChoice]
  model: str
  object: str = 'chat.completion'
  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))

@attr.define
class ChatCompletionResponseStream:
  choices: t.List[ChatCompletionStreamChoice]
  model: str
  object: str = 'chat.completion.chunk'
  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))

def messages_to_prompt(messages: list[Message]) -> str:
  formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
  return f'{formatted}\nassistant:'
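To illustrate the prompt flattening above, the sample messages used by the chat endpoint render as follows:

from openllm.protocol.openai import messages_to_prompt

print(messages_to_prompt([{'role': 'system', 'content': 'You are a helpful assistant.'},
                          {'role': 'user', 'content': 'Hello!'}]))
# system: You are a helpful assistant.
# user: Hello!
# assistant: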