perf: unify LLM interface (#518)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by Aaron Pham on 2023-11-06 20:39:43 -05:00, committed by GitHub.
parent f2639879af
commit e2029c934b
136 changed files with 9646 additions and 11244 deletions


@@ -0,0 +1,18 @@
+from __future__ import annotations
+import typing as t
+
+import attr
+
+@attr.define
+class AgentRequest:
+  inputs: str
+  parameters: t.Dict[str, t.Any]
+
+@attr.define
+class AgentResponse:
+  generated_text: str
+
+@attr.define
+class AgentErrorResponse:
+  error_code: int
+  message: str
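For context, a minimal sketch of how this new agent schema round-trips through attrs' public API; the payload values are made up, and the class is repeated from the hunk above so the snippet is self-contained:

import typing as t

import attr

@attr.define
class AgentRequest:
  inputs: str
  parameters: t.Dict[str, t.Any]

req = AgentRequest(inputs='What is the capital of France?', parameters={'max_new_tokens': 64})
payload = attr.asdict(req)             # plain dict, ready for JSON serialization
assert AgentRequest(**payload) == req  # attrs generates __init__ and __eq__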


@@ -6,7 +6,15 @@ import attr
 import openllm_core
 from openllm import _conversation
 from openllm_core.utils import converter
 
+@attr.define
+class ErrorResponse:
+  message: str
+  type: str
+  object: str = 'error'
+  param: t.Optional[str] = None
+  code: t.Optional[str] = None
+
 @attr.define
 class CompletionRequest:
@@ -15,7 +23,7 @@ class CompletionRequest:
   suffix: t.Optional[str] = attr.field(default=None)
   max_tokens: t.Optional[int] = attr.field(default=16)
   temperature: t.Optional[float] = attr.field(default=1.0)
-  top_p: t.Optional[float] = attr.field(default=1)
+  top_p: t.Optional[float] = attr.field(default=1.0)
   n: t.Optional[int] = attr.field(default=1)
   stream: t.Optional[bool] = attr.field(default=False)
   logprobs: t.Optional[int] = attr.field(default=None)
@@ -23,9 +31,11 @@ class CompletionRequest:
   stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
   presence_penalty: t.Optional[float] = attr.field(default=0.0)
   frequency_penalty: t.Optional[float] = attr.field(default=0.0)
-  best_of: t.Optional[int] = attr.field(default=1)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
   user: t.Optional[str] = attr.field(default=None)
+  # supported by vLLM and us
+  top_k: t.Optional[int] = attr.field(default=None)
+  best_of: t.Optional[int] = attr.field(default=1)
 
 @attr.define
 class ChatCompletionRequest:
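A hypothetical client construction against the reshuffled request schema; the class here is trimmed to a handful of the fields shown above, and OpenAI-compatible clients can simply leave the vLLM-only extensions unset:

import typing as t

import attr

@attr.define
class CompletionRequest:  # trimmed, hypothetical subset of the fields above
  prompt: str
  temperature: t.Optional[float] = attr.field(default=1.0)
  top_p: t.Optional[float] = attr.field(default=1.0)
  # supported by vLLM and us
  top_k: t.Optional[int] = attr.field(default=None)
  best_of: t.Optional[int] = attr.field(default=1)

req = CompletionRequest(prompt='Once upon a time', temperature=0.7, top_k=40)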
@@ -33,16 +43,19 @@ class ChatCompletionRequest:
   model: str = attr.field(default=None)
   functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
   function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
-  temperature: t.Optional[float] = attr.field(default=1.0)
-  top_p: t.Optional[float] = attr.field(default=1)
-  n: t.Optional[int] = attr.field(default=1)
+  temperature: t.Optional[float] = attr.field(default=None)
+  top_p: t.Optional[float] = attr.field(default=None)
+  n: t.Optional[int] = attr.field(default=None)
   stream: t.Optional[bool] = attr.field(default=False)
   stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
   max_tokens: t.Optional[int] = attr.field(default=None)
-  presence_penalty: t.Optional[float] = attr.field(default=0.0)
-  frequency_penalty: t.Optional[float] = attr.field(default=0.0)
+  presence_penalty: t.Optional[float] = attr.field(default=None)
+  frequency_penalty: t.Optional[float] = attr.field(default=None)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
   user: t.Optional[str] = attr.field(default=None)
+  # supported by vLLM and us
+  top_k: t.Optional[int] = attr.field(default=None)
+  best_of: t.Optional[int] = attr.field(default=1)
 
 @attr.define
 class LogProbs:
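Moving the chat sampling defaults from fixed values (1.0, 1) to None reads as a fallback mechanism: an unset field can now be filled from the model's own configuration instead of silently overriding it. A sketch of that resolution logic, with config_default standing in for a value from openllm_core.LLMConfig (the helper name is hypothetical):

import typing as t

def resolve(request_value: t.Optional[float], config_default: float) -> float:
  # Explicit client values win; None falls back to the model config.
  return config_default if request_value is None else request_value

assert resolve(None, 0.9) == 0.9  # client left top_p unset -> model default wins
assert resolve(0.2, 0.9) == 0.2   # explicit client value preserved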
@@ -52,80 +65,90 @@ class LogProbs:
   top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
 
 @attr.define
-class CompletionTextChoice:
-  text: str
-  index: int
-  logprobs: LogProbs = attr.field(default=attr.Factory(lambda: LogProbs()))
-  finish_reason: str = attr.field(default=None)
-
-@attr.define
-class Usage:
+class UsageInfo:
   prompt_tokens: int = attr.field(default=0)
   completion_tokens: int = attr.field(default=0)
   total_tokens: int = attr.field(default=0)
 
 @attr.define
-class CompletionResponse:
-  choices: t.List[CompletionTextChoice]
+class CompletionResponseChoice:
+  index: int
+  text: str
+  logprobs: t.Optional[LogProbs] = None
+  finish_reason: t.Optional[str] = None
+
+@attr.define
+class CompletionResponseStreamChoice:
+  index: int
+  text: str
+  logprobs: t.Optional[LogProbs] = None
+  finish_reason: t.Optional[str] = None
+
+@attr.define
+class CompletionStreamResponse:
   model: str
+  choices: t.List[CompletionResponseStreamChoice]
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
-  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
 
 @attr.define
-class CompletionResponseStream:
-  choices: t.List[CompletionTextChoice]
+class CompletionResponse:
+  choices: t.List[CompletionResponseChoice]
   model: str
+  usage: UsageInfo
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
 
 LiteralRole = t.Literal['system', 'user', 'assistant']
 
 class Message(t.TypedDict):
   role: LiteralRole
   content: str
 
-@attr.define
-class ChatCompletionChoice:
-  index: int
-  message: Message
-  finish_reason: str = attr.field(default=None)
+@attr.define
+class Delta:
+  role: t.Optional[LiteralRole] = None
+  content: t.Optional[str] = None
+
+@attr.define
+class ChatMessage:
+  role: LiteralRole
+  content: str
+
+converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})
 
 @attr.define
-class ChatCompletionStreamChoice:
+class ChatCompletionResponseStreamChoice:
   index: int
-  delta: Message
-  finish_reason: str = attr.field(default=None)
+  delta: Delta
+  finish_reason: t.Optional[str] = attr.field(default=None)
+
+@attr.define
+class ChatCompletionResponseChoice:
+  index: int
+  message: ChatMessage
+  finish_reason: t.Optional[str] = attr.field(default=None)
 
 @attr.define
 class ChatCompletionResponse:
-  choices: t.List[ChatCompletionChoice]
+  choices: t.List[ChatCompletionResponseChoice]
   model: str
   object: str = 'chat.completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
-  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
-  usage: Usage = attr.field(default=attr.Factory(lambda: Usage()))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: UsageInfo = attr.field(default=attr.Factory(lambda: UsageInfo()))
 
 @attr.define
-class ChatCompletionResponseStream:
-  choices: t.List[ChatCompletionStreamChoice]
+class ChatCompletionStreamResponse:
+  choices: t.List[ChatCompletionResponseStreamChoice]
   model: str
   object: str = 'chat.completion.chunk'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
-  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
 
 @attr.define
 class ModelCard:
   id: str
   object: str = 'model'
-  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
   owned_by: str = 'na'
 
 @attr.define
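The register_unstructure_hook call in the hunk above makes ChatMessage serialize as a plain role/content dict, and cattrs applies the hook recursively inside the response types. A self-contained sketch, with a locally created cattrs.Converter standing in for openllm_core.utils.converter and the classes trimmed from the diff:

import typing as t

import attr
import cattrs

@attr.define
class ChatMessage:
  role: str  # LiteralRole in the real module
  content: str

@attr.define
class ChatCompletionResponseChoice:
  index: int
  message: ChatMessage
  finish_reason: t.Optional[str] = None

converter = cattrs.Converter()
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})

choice = ChatCompletionResponseChoice(index=0, message=ChatMessage('assistant', 'Hi there!'), finish_reason='stop')
print(converter.unstructure(choice))
# {'index': 0, 'message': {'role': 'assistant', 'content': 'Hi there!'}, 'finish_reason': 'stop'}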
@@ -133,10 +156,14 @@ class ModelList:
   object: str = 'list'
   data: t.List[ModelCard] = attr.field(factory=list)
 
-def messages_to_prompt(messages: list[Message], model: str, llm_config: openllm_core.LLMConfig) -> str:
-  conv_template = _conversation.get_conv_template(model, llm_config)
-  for message in messages:
-    if message['role'] == 'system': conv_template.set_system_message(message['content'])
-    else: conv_template.append_message(message['role'], message['content'])
-  conv_template.append_message('assistant', '')
-  return conv_template.get_prompt()
+async def get_conversation_prompt(request: ChatCompletionRequest, llm_config: openllm_core.LLMConfig) -> str:
+  conv = llm_config.get_conversation_template()
+  for message in request.messages:
+    msg_role = message['role']
+    if msg_role == 'system': conv.set_system_message(message['content'])
+    elif msg_role == 'user': conv.append_message(conv.roles[0], message['content'])
+    elif msg_role == 'assistant': conv.append_message(conv.roles[1], message['content'])
+    else: raise ValueError(f'Unknown role: {msg_role}')
+  # Add a blank message for the assistant.
+  conv.append_message(conv.roles[1], '')
+  return conv.get_prompt()
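A stub that exercises the same prompt-building loop as get_conversation_prompt above; StubConversation is hypothetical and implements only the surface the function touches (set_system_message, append_message, roles, get_prompt):

import typing as t

class StubConversation:
  roles = ('USER', 'ASSISTANT')

  def __init__(self) -> None:
    self.system = ''
    self.messages: t.List[t.Tuple[str, str]] = []

  def set_system_message(self, message: str) -> None:
    self.system = message

  def append_message(self, role: str, content: str) -> None:
    self.messages.append((role, content))

  def get_prompt(self) -> str:
    return self.system + '\n' + ''.join(f'{role}: {content}\n' for role, content in self.messages)

conv = StubConversation()
conv.set_system_message('You are a helpful assistant.')
conv.append_message(conv.roles[0], 'Hello!')
conv.append_message(conv.roles[1], '')  # blank assistant turn for the model to complete
print(conv.get_prompt())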