From 63d86faa323184f24520d89c3a1e46bb5dd872cc Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Wed, 22 Nov 2023 04:21:13 -0500
Subject: [PATCH] fix(openai): correct stop tokens and finish_reason state (#722)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
---
 .../src/openllm_core/_configuration.py        | 28 +++++++++
 openllm-core/src/openllm_core/_schemas.py     | 27 +++++++-
 openllm-python/src/openllm/_llm.py            |  4 +-
 openllm-python/src/openllm/_runners.py        | 28 +--------
 .../src/openllm/entrypoints/openai.py         | 61 +++++++++++--------
 openllm-python/src/openllm/protocol/openai.py | 18 ++++--
 openllm-python/src/openllm_cli/entrypoint.py  |  1 +
 7 files changed, 103 insertions(+), 64 deletions(-)

diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py
index a0ac9f88..8c33c549 100644
--- a/openllm-core/src/openllm_core/_configuration.py
+++ b/openllm-core/src/openllm_core/_configuration.py
@@ -280,6 +280,9 @@ class SamplingParams(ReprMixin):
   repetition_penalty: float
   length_penalty: float
   early_stopping: bool
+  prompt_logprobs: int
+  logprobs: int
+  stop: t.Optional[t.Union[str, t.List[str]]]
 
   def __init__(self, *, _internal: bool = False, **attrs: t.Any):
     if not _internal:
@@ -293,6 +296,9 @@ class SamplingParams(ReprMixin):
     _object_setattr(self, 'repetition_penalty', attrs.pop('repetition_penalty', 1.0))
     _object_setattr(self, 'length_penalty', attrs.pop('length_penalty', 1.0))
     _object_setattr(self, 'early_stopping', attrs.pop('early_stopping', False))
+    _object_setattr(self, 'logprobs', attrs.pop('logprobs', 0))
+    _object_setattr(self, 'prompt_logprobs', attrs.pop('prompt_logprobs', 0))
+    _object_setattr(self, 'stop', attrs.pop('stop', None))
     self.__attrs_init__(**attrs)
 
   def __getitem__(self, item: str) -> t.Any:
@@ -312,6 +318,10 @@ class SamplingParams(ReprMixin):
       temperature=self.temperature,
       top_k=self.top_k,
       top_p=self.top_p,
+      repetition_penalty=self.repetition_penalty,
+      logprobs=self.logprobs,
+      prompt_logprobs=self.prompt_logprobs,
+      stop=self.stop,
       **converter.unstructure(self),
     )
 
@@ -331,6 +341,14 @@ class SamplingParams(ReprMixin):
     )
     length_penalty = first_not_none(attrs.pop('length_penalty', None), default=generation_config['length_penalty'])
     early_stopping = first_not_none(attrs.pop('early_stopping', None), default=generation_config['early_stopping'])
+    logprobs = first_not_none(attrs.pop('logprobs', None), default=generation_config['logprobs'])
+    prompt_logprobs = first_not_none(attrs.pop('prompt_logprobs', None), default=generation_config['prompt_logprobs'])
+    stop = attrs.pop('stop', None)
+    if stop is None:
+      try:
+        stop = generation_config['stop']
+      except KeyError:
+        pass
 
     return cls(
       _internal=True,
@@ -341,6 +359,9 @@ class SamplingParams(ReprMixin):
       repetition_penalty=repetition_penalty,
       length_penalty=length_penalty,
       early_stopping=early_stopping,
+      logprobs=logprobs,
+      prompt_logprobs=prompt_logprobs,
+      stop=stop,
       **attrs,
     )
 
@@ -1469,7 +1490,14 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]):
         top_p = 1.0
       else:
         top_p = config['top_p']
+      try:
+        stop = config['stop']
+      except KeyError:
+        stop = None
+      _object_setattr(config.sampling_config, 'stop', stop)
       _object_setattr(config.sampling_config, 'top_p', top_p)
+      _object_setattr(config.sampling_config, 'logprobs', config['logprobs'])
+      _object_setattr(config.sampling_config, 'prompt_logprobs', config['prompt_logprobs'])
       return config.sampling_config.build()
 
   class ctranslate:
diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py
index 537e97b7..14023053 100644
--- a/openllm-core/src/openllm_core/_schemas.py
+++ b/openllm-core/src/openllm_core/_schemas.py
@@ -162,16 +162,39 @@ class GenerationOutput(_SchemaMixin):
     if not data:
       raise ValueError('No data found from messages.')
     try:
-      return converter.structure(orjson.loads(data), cls)
+      structured = orjson.loads(data)
     except orjson.JSONDecodeError as e:
       raise ValueError(f'Failed to parse JSON from SSE message: {data!r}') from e
+    if structured['prompt_logprobs']: structured['prompt_logprobs'] = [{int(k): v for k,v in it.items()} if it else None for it in structured['prompt_logprobs']]
+
+    return cls(
+      prompt=structured['prompt'],
+      finished=structured['finished'],
+      prompt_token_ids=structured['prompt_token_ids'],
+      prompt_logprobs=structured['prompt_logprobs'],
+      request_id=structured['request_id'],
+      outputs=[
+        CompletionChunk(
+          index=it['index'],
+          text=it['text'],
+          token_ids=it['token_ids'],
+          cumulative_logprob=it['cumulative_logprob'],
+          finish_reason=it['finish_reason'],
+          logprobs=[{int(k): v for k,v in s.items()} for s in it['logprobs']] if it['logprobs'] else None,
+        )
+        for it in structured['outputs']
+      ],
+    )
+
 
   @classmethod
   def from_vllm(cls, request_output: vllm.RequestOutput) -> GenerationOutput:
     return cls(
       prompt=request_output.prompt,
       finished=request_output.finished,
       request_id=request_output.request_id,
+      prompt_token_ids=request_output.prompt_token_ids,
+      prompt_logprobs=request_output.prompt_logprobs,
       outputs=[
         CompletionChunk(
           index=it.index,
@@ -183,8 +206,6 @@ class GenerationOutput(_SchemaMixin):
         )
         for it in request_output.outputs
       ],
-      prompt_token_ids=request_output.prompt_token_ids,
-      prompt_logprobs=request_output.prompt_logprobs,
     )
 
   def model_dump_json(self) -> str:
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index e79cc1d9..2c18c099 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -118,7 +118,7 @@ class LLM(t.Generic[M, T], ReprMixin):
     previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
     try:
       generator = self.runner.generate_iterator.async_stream(
-        prompt_token_ids, request_id, stop=stop, adapter_name=adapter_name, **config.model_dump(flatten=True)
+        prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)
       )
     except Exception as err:
       raise RuntimeError(f'Failed to start generation task: {err}') from err
@@ -127,8 +127,6 @@ class LLM(t.Generic[M, T], ReprMixin):
     async for out in generator:
       generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
       delta_outputs = [None] * len(generated.outputs)
-      if generated.finished:
-        break
       for output in generated.outputs:
         i = output.index
         delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py
index a9c72ca9..16c4f6bf 100644
--- a/openllm-python/src/openllm/_runners.py
+++ b/openllm-python/src/openllm/_runners.py
@@ -89,13 +89,7 @@ class CTranslateRunnable(bentoml.Runnable):
 
   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
 
     cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
     tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
@@ -165,18 +159,9 @@ class vLLMRunnable(bentoml.Runnable):
 
   @bentoml.Runnable.method(batchable=False)
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    config, sampling_params = self.config.model_construct_env(stop=list(stop_), **attrs).inference_options(self.llm)
+    config, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
 
     async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
-      # XXX: Need to write a hook for serialisation None correctly
-      if request_output.prompt_logprobs is not None:
-        request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
       yield GenerationOutput.from_vllm(request_output).model_dump_json()
 
 
@@ -197,14 +182,7 @@ class PyTorchRunnable(bentoml.Runnable):
   async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
     if adapter_name is not None:
       self.model.set_adapter(adapter_name)
-
-    stop_ = set()
-    if isinstance(stop, str) and stop != '':
-      stop_.add(stop)
-    elif isinstance(stop, t.Iterable):
-      stop_.update(stop)
-
-    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop_), **attrs):
+    async for generation_output in self.forward(prompt_token_ids, request_id, list(stop), **attrs):
       yield generation_output.model_dump_json()
 
   async def forward(self, prompt_token_ids, request_id, stop, **attrs):
diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py
index 835e78c6..f263c0db 100644
--- a/openllm-python/src/openllm/entrypoints/openai.py
+++ b/openllm-python/src/openllm/entrypoints/openai.py
@@ -10,7 +10,6 @@
 from starlette.applications import Starlette
 from starlette.responses import JSONResponse, StreamingResponse
 from starlette.routing import Route
-from openllm_core._schemas import SampleLogprobs
 from openllm_core.utils import converter, gen_random_uuid
 
 from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
@@ -165,34 +164,38 @@ async def chat_completions(req, llm):
     logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
 
-  def create_stream_response_json(index, text, finish_reason=None):
-    return jsonify_attr(
-      ChatCompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
-        ],
-      )
+  def create_stream_response_json(index, text, finish_reason=None, usage=None):
+    response = ChatCompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
+      ],
     )
+    if usage is not None: response.usage = usage
+    return jsonify_attr(response)
 
-  async def completion_stream_generator():
+  async def chat_completion_stream_generator():
     # first chunk with role
     for i in range(config['n']):
       yield f"data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role='assistant'), finish_reason=None)], model=model_name))}\n\n"
 
+    previous_num_tokens = [0] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
+        previous_num_tokens[output.index] += len(output.token_ids)
         if output.finish_reason is not None:
-          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'
 
   try:
     # Streaming case
     if request.stream:
-      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
+      return StreamingResponse(chat_completion_stream_generator(), media_type='text/event-stream')
     # Non-streaming case
     final_result = None
     texts, token_ids = [[]] * config['n'], [[]] * config['n']
@@ -287,34 +290,38 @@ async def completions(req, llm):
   # TODO: support use_beam_search
   stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
 
-  def create_stream_response_json(index, text, logprobs=None, finish_reason=None):
-    return jsonify_attr(
-      CompletionStreamResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=[
-          CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
-        ],
-      )
+  def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
+    response = CompletionStreamResponse(
+      id=request_id,
+      created=created_time,
+      model=model_name,
+      choices=[
+        CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
+      ],
     )
+    if usage: response.usage = usage
+    return jsonify_attr(response)
 
   async def completion_stream_generator():
     previous_num_tokens = [0] * config['n']
+    previous_texts = [''] * config['n']
     async for res in result_generator:
       for output in res.outputs:
         i = output.index
         if request.logprobs is not None:
           logprobs = create_logprobs(
-            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i] :], llm=llm
+            token_ids=output.token_ids, id_logprobs=output.logprobs[previous_num_tokens[i]:], initial_text_offset=len(previous_texts[i]), llm=llm
           )
         else:
           logprobs = None
         previous_num_tokens[i] += len(output.token_ids)
+        previous_texts[i] += output.text
         yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs)}\n\n'
         if output.finish_reason is not None:
           logprobs = LogProbs() if request.logprobs is not None else None
-          yield f'data: {create_stream_response_json(index=i, text="", logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
+          prompt_tokens = len(res.prompt_token_ids)
+          usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
+          yield f'data: {create_stream_response_json(i, "", logprobs, output.finish_reason, usage)}\n\n'
     yield 'data: [DONE]\n\n'
 
   try:
@@ -344,7 +351,7 @@ async def completions(req, llm):
   for output in final_result.outputs:
     if request.logprobs is not None:
       logprobs = create_logprobs(
-        token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
+        token_ids=output.token_ids, id_logprobs=output.logprobs, llm=llm
       )
     else:
       logprobs = None
diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py
index 740f69da..9fd29ada 100644
--- a/openllm-python/src/openllm/protocol/openai.py
+++ b/openllm-python/src/openllm/protocol/openai.py
@@ -5,6 +5,7 @@
 import typing as t
 
 import attr
 
 import openllm_core
+from openllm_core._schemas import FinishReason
 from openllm_core.utils import converter
@@ -16,6 +17,9 @@ class ErrorResponse:
   param: t.Optional[str] = None
   code: t.Optional[str] = None
 
+def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]:
+  if not data: return None
+  return [data] if isinstance(data, str) else data
 
 @attr.define
 class CompletionRequest:
@@ -29,7 +33,7 @@ class CompletionRequest:
   stream: t.Optional[bool] = attr.field(default=False)
   logprobs: t.Optional[int] = attr.field(default=None)
   echo: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   presence_penalty: t.Optional[float] = attr.field(default=0.0)
   frequency_penalty: t.Optional[float] = attr.field(default=0.0)
   logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
@@ -49,7 +53,7 @@ class ChatCompletionRequest:
   top_p: t.Optional[float] = attr.field(default=None)
   n: t.Optional[int] = attr.field(default=None)
   stream: t.Optional[bool] = attr.field(default=False)
-  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None)
+  stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
   max_tokens: t.Optional[int] = attr.field(default=None)
   presence_penalty: t.Optional[float] = attr.field(default=None)
   frequency_penalty: t.Optional[float] = attr.field(default=None)
@@ -80,7 +84,7 @@ class CompletionResponseChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 
 
 @attr.define
@@ -88,7 +92,7 @@ class CompletionResponseStreamChoice:
   index: int
   text: str
   logprobs: t.Optional[LogProbs] = None
-  finish_reason: t.Optional[str] = None
+  finish_reason: t.Optional[FinishReason] = None
 
 
 @attr.define
@@ -98,6 +102,7 @@ class CompletionStreamResponse:
   object: str = 'text_completion'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 
 
 @attr.define
@@ -132,14 +137,14 @@ converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role,
 class ChatCompletionResponseStreamChoice:
   index: int
   delta: Delta
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 
 
 @attr.define
 class ChatCompletionResponseChoice:
   index: int
   message: ChatMessage
-  finish_reason: t.Optional[str] = attr.field(default=None)
+  finish_reason: t.Optional[FinishReason] = None
 
 
 @attr.define
@@ -159,6 +164,7 @@ class ChatCompletionStreamResponse:
   object: str = 'chat.completion.chunk'
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
+  usage: t.Optional[UsageInfo] = attr.field(default=None)
 
 
 @attr.define
diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py
index c9c86d48..9effae62 100644
--- a/openllm-python/src/openllm_cli/entrypoint.py
+++ b/openllm-python/src/openllm_cli/entrypoint.py
@@ -442,6 +442,7 @@ def start_command(
     serialisation=serialisation,
     dtype=dtype,
     max_model_len=max_model_len,
+    trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
   )
 
   backend_warning(llm.__llm_backend__)
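Note (not part of the patch): a minimal sketch of how the corrected behaviour could be exercised against OpenLLM's OpenAI-compatible endpoints with the `openai` Python client. The base URL, API key, prompt, and stop sequence below are illustrative assumptions, not values taken from this commit.

# Sketch only: drives the legacy completions endpoint of an OpenAI-compatible
# server (such as the one OpenLLM exposes) to observe `stop` handling and the
# finish_reason on the final streamed chunk. Endpoint details are assumptions.
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
model_id = client.models.list().data[0].id

# `stop` may be a single string or a list of strings; the patch normalises both
# forms via `_stop_converter` before they reach the sampling parameters.
stream = client.completions.create(
  model=model_id,
  prompt='The quick brown fox',
  max_tokens=32,
  stop=['\n'],
  stream=True,
)
for chunk in stream:
  choice = chunk.choices[0]
  print(choice.text, end='', flush=True)
  if choice.finish_reason is not None:
    # With the fix, the final chunk reports a finish_reason (e.g. 'stop' or
    # 'length') instead of being dropped when generation ends.
    print(f'\n[finish_reason={choice.finish_reason}]')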