From 775ddab94106e756bce4956c0bcbd2f0b82eb500 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Sat, 27 May 2023 05:23:44 -0700 Subject: [PATCH] fix(generation): correct type of top_k to int instead of float Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- src/openllm/models/dolly_v2/modeling_dolly_v2.py | 4 ++-- src/openllm/models/flan_t5/modeling_flan_t5.py | 4 ++-- src/openllm/models/flan_t5/modeling_flax_flan_t5.py | 4 ++-- src/openllm/models/flan_t5/modeling_tf_flan_t5.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index e26b9a58..3075caf1 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -70,7 +70,7 @@ class DollyV2(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, **kwargs: t.Any, ) -> tuple[str, dict[str, t.Any]]: @@ -97,7 +97,7 @@ class DollyV2(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, **kwargs: t.Any, ): diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index c9e51560..6d7eee0b 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -43,7 +43,7 @@ class FlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any, @@ -66,7 +66,7 @@ class FlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any, diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index b170868e..155401b1 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -36,7 +36,7 @@ class FlaxFlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any, @@ -58,7 +58,7 @@ class FlaxFlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any, diff --git a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py index f1d150a6..7d724d60 100644 --- a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py @@ -36,7 +36,7 @@ class TFFlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any, @@ -58,7 +58,7 @@ class TFFlanT5(openllm.LLM): prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, - top_k: float | None = None, + top_k: int| None = None, top_p: float | None = None, repetition_penalty: float | None = None, **kwargs: t.Any,