diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py
index b8940140..148f7b63 100644
--- a/openllm-core/src/openllm_core/_configuration.py
+++ b/openllm-core/src/openllm_core/_configuration.py
@@ -213,6 +213,12 @@ class GenerationConfig(pydantic.BaseModel):
     description='The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.',
     alias='max_new_tokens',
   )
+  min_tokens: int = pydantic.Field(
+    0,
+    ge=0,
+    description='Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated',
+    alias='min_new_tokens',
+  )
   logprobs: t.Optional[int] = pydantic.Field(
     None, description='Number of log probabilities to return per output token.'
   )
@@ -454,6 +460,8 @@ class LLMConfig(pydantic.BaseModel, abc.ABC):
   @overload
   def __getitem__(self, item: t.Literal['max_tokens']) -> int: ...
   @overload
+  def __getitem__(self, item: t.Literal['min_tokens']) -> int: ...
+  @overload
   def __getitem__(self, item: t.Literal['logprobs']) -> t.Optional[int]: ...
   @overload
   def __getitem__(self, item: t.Literal['prompt_logprobs']) -> t.Optional[int]: ...
diff --git a/openllm-python/src/_openllm_tiny/_llm.py b/openllm-python/src/_openllm_tiny/_llm.py
index 3b4ad873..9a813728 100644
--- a/openllm-python/src/_openllm_tiny/_llm.py
+++ b/openllm-python/src/_openllm_tiny/_llm.py
@@ -171,6 +171,7 @@ class LLM:
     top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
     config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
+
     sampling_params = SamplingParams(**{
       k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())
     })
diff --git a/openllm-python/src/_openllm_tiny/_service.py b/openllm-python/src/_openllm_tiny/_service.py
index d0fa6838..f8b9abc6 100644
--- a/openllm-python/src/_openllm_tiny/_service.py
+++ b/openllm-python/src/_openllm_tiny/_service.py
@@ -92,7 +92,7 @@ class LLMService:
       stop_token_ids=stop_token_ids,
       request_id=request_id,
     ):
-      yield f'data: {generated.model_dump_json()}\n\n'
+      yield f'data: {core.GenerationOutput.from_vllm(generated).model_dump_json()}\n\n'
     yield 'data: [DONE]\n\n'
 
   @core.utils.api(route='/v1/metadata')
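
Note: the _llm.py hunk builds vLLM SamplingParams by filtering the config against inspect.signature(SamplingParams), so the new min_tokens field only takes effect when the installed vLLM exposes a min_tokens parameter. Below is a minimal standalone sketch of that filtering pattern; the plain-dict config is hypothetical and stands in for OpenLLM's LLMConfig object.

# Sketch only: mirrors the signature-introspection pattern from the _llm.py hunk.
# Assumes vllm.SamplingParams accepts `min_tokens` (present in recent vLLM releases);
# the `config` dict is illustrative, not OpenLLM's config object.
import inspect

from vllm import SamplingParams

config = {'max_tokens': 128, 'min_tokens': 8, 'temperature': 0.7, 'unrelated_field': True}

# Keep only keys that SamplingParams actually accepts; anything else is dropped.
accepted = set(inspect.signature(SamplingParams).parameters)
sampling_params = SamplingParams(**{k: v for k, v in config.items() if k in accepted})

print(sampling_params.min_tokens)  # expected: 8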