fix(openai): correct stop tokens and finish_reason state (#722)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

Author: Aaron Pham
Date: 2023-11-22 04:21:13 -05:00
Committed by: GitHub
Parent: 06626e7d1e
Commit: 63d86faa32

7 changed files with 103 additions and 64 deletions


@@ -5,6 +5,7 @@ import typing as t
 import attr
 import openllm_core
+from openllm_core._schemas import FinishReason
 from openllm_core.utils import converter
@@ -16,6 +17,9 @@ class ErrorResponse:
   param: t.Optional[str] = None
   code: t.Optional[str] = None
 
+def _stop_converter(data: t.Optional[t.Union[str, t.List[str]]]) -> t.Optional[t.List[str]]:
+  if not data: return None
+  return [data] if isinstance(data, str) else data
 
 @attr.define
 class CompletionRequest:
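
The new _stop_converter normalizes the OpenAI-style stop parameter, which may arrive as a single string or a list of strings, into a list (or None when absent). A minimal standalone sketch of the same normalization, with hypothetical inputs:

import typing as t

def _stop_converter(data: t.Optional[t.Union[str, t.List[str]]]) -> t.Optional[t.List[str]]:
  if not data: return None
  return [data] if isinstance(data, str) else data

# Hypothetical inputs illustrating the normalization:
assert _stop_converter(None) is None                     # absent -> None
assert _stop_converter('') is None                       # empty string -> None
assert _stop_converter('\n') == ['\n']                   # single token -> one-element list
assert _stop_converter(['\n', '###']) == ['\n', '###']   # list passes through unchanged
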
@@ -29,7 +33,7 @@ class CompletionRequest:
   stream: t.Optional[bool] = attr.field(default=False)
   logprobs: t.Optional[int] = attr.field(default=None)
   echo: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   presence_penalty: t.Optional[float] = attr.field(default=0.0)
   frequency_penalty: t.Optional[float] = attr.field(default=0.0)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
@@ -49,7 +53,7 @@ class ChatCompletionRequest:
   top_p: t.Optional[float] = attr.field(default=None)
   n: t.Optional[int] = attr.field(default=None)
   stream: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   max_tokens: t.Optional[int] = attr.field(default=None)
   presence_penalty: t.Optional[float] = attr.field(default=None)
   frequency_penalty: t.Optional[float] = attr.field(default=None)
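
Because the converter is attached via attr.field(converter=...), attrs runs it during __init__, so every consumer of these request objects observes an already-normalized list. A sketch of the effect, assuming a trimmed-down request class (the real CompletionRequest and ChatCompletionRequest carry many more fields):

import typing as t
import attr

def _stop_converter(data: t.Optional[t.Union[str, t.List[str]]]) -> t.Optional[t.List[str]]:
  if not data: return None
  return [data] if isinstance(data, str) else data

@attr.define
class TinyRequest:  # hypothetical stand-in for CompletionRequest/ChatCompletionRequest
  prompt: str
  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)

assert TinyRequest('hi', stop='\n').stop == ['\n']  # bare string is wrapped in a list
assert TinyRequest('hi').stop is None               # default None survives the converter
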
@@ -80,7 +84,7 @@ class CompletionResponseChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 
 @attr.define
@@ -88,7 +92,7 @@ class CompletionResponseStreamChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 
 @attr.define
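
Typing finish_reason as FinishReason rather than a bare str narrows the terminal states a choice can report. A sketch of the idea, assuming FinishReason is roughly a Literal alias over the OpenAI-style finish states (the real definition lives in openllm_core._schemas and may differ):

import typing as t

# Assumption: FinishReason is approximately this Literal; see openllm_core._schemas for the real one.
FinishReason = t.Literal['stop', 'length']

def describe(finish_reason: t.Optional[FinishReason]) -> str:
  if finish_reason is None: return 'still generating'
  if finish_reason == 'stop': return 'hit a stop token'
  return 'hit the max_tokens limit'

# A type checker now flags e.g. describe('lenght') instead of silently accepting the typo.
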
@@ -98,6 +102,7 @@ class CompletionStreamResponse:
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 
 @attr.define
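
With usage now present on the stream response, the server can attach token accounting to a chunk (typically the final one) instead of dropping it entirely in streaming mode. A hedged sketch, assuming UsageInfo mirrors the usual OpenAI usage triple:

import typing as t
import attr

@attr.define
class UsageInfo:  # assumption: mirrors the OpenAI usage object
  prompt_tokens: int = 0
  completion_tokens: int = 0
  total_tokens: int = 0

@attr.define
class Chunk:  # hypothetical stand-in for CompletionStreamResponse
  choices: t.List[str]
  usage: t.Optional[UsageInfo] = attr.field(default=None)

# Intermediate chunks omit usage; the final chunk reports the totals.
final = Chunk(choices=[], usage=UsageInfo(prompt_tokens=5, completion_tokens=42, total_tokens=47))
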
@@ -132,14 +137,14 @@ converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role,
 class ChatCompletionResponseStreamChoice:
   index: int
   delta: Delta
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 
 @attr.define
 class ChatCompletionResponseChoice:
   index: int
   message: ChatMessage
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 
 @attr.define
@@ -159,6 +164,7 @@ class ChatCompletionStreamResponse:
   object: str = 'chat.completion.chunk'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 
 @attr.define
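
The converter imported from openllm_core.utils is a cattrs converter, and hooks like the ChatMessage one visible in the hunk context above control how these attrs classes serialize to OpenAI-compatible JSON. A minimal sketch of the same pattern with a simplified class:

import attr
import cattrs

converter = cattrs.Converter()

@attr.define
class ChatMessage:  # simplified stand-in for the real ChatMessage
  role: str
  content: str

# Serialize as a plain dict, matching the wire format OpenAI clients expect.
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})

assert converter.unstructure(ChatMessage('user', 'hi')) == {'role': 'user', 'content': 'hi'}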