Mirror of https://github.com/bentoml/OpenLLM.git
fix(configuration): Make sure GenerationInput dumps the correct dictionary for llm_config

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -44,6 +44,14 @@ class GenerationInput(pydantic.BaseModel):
         llm_config=(llm_config.__class__, ...),
     )
 
+    # XXX: Need more investigation why llm_config.model_dump is not invoked
+    # recursively when GenerationInput.model_dump is called
+    def model_dump(self, **kwargs: t.Any):
+        """Override the default model_dump to make sure llm_config is correctly flattened."""
+        dumped = super().model_dump(**kwargs)
+        dumped['llm_config'] = self.llm_config.model_dump(flatten=True)
+        return dumped
+
 
 class GenerationOutput(pydantic.BaseModel):
     model_config = {"extra": "forbid"}
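
The XXX comment points at a real pydantic v2 behavior: serialization runs in the compiled core, so an overridden model_dump on a nested model is not invoked when the parent model is dumped. A minimal standalone sketch of the problem and the same workaround, with illustrative class names rather than OpenLLM's actual ones:

    import typing as t

    import pydantic

    class SubConfig(pydantic.BaseModel):
        temperature: float = 0.9

        def model_dump(self, **kwargs: t.Any):
            # This override is NOT called when a parent model dumps this
            # field; pydantic's core serializer bypasses Python overrides.
            dumped = super().model_dump(**kwargs)
            dumped["marker"] = "custom"
            return dumped

    class Request(pydantic.BaseModel):
        sub: SubConfig = SubConfig()

        def model_dump(self, **kwargs: t.Any):
            # Same fix as the diff: re-dump the nested model explicitly so
            # its Python-level override is honored.
            dumped = super().model_dump(**kwargs)
            dumped["sub"] = self.sub.model_dump()
            return dumped

    print(Request().model_dump())
    # {'sub': {'temperature': 0.9, 'marker': 'custom'}}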
@@ -25,7 +25,6 @@ svc = bentoml.Service(name=f"llm-{llm_config.__openllm_start_name__}-service", r
     route="/v1/generate",
 )
 async def generate_v1(qa: openllm.GenerationInput) -> openllm.GenerationOutput:
-    print(qa)
     config = llm_config.with_options(__llm_config__=qa.llm_config).model_dump()
     responses = await runner.generate.async_run(qa.prompt, **config)
     return openllm.GenerationOutput(responses=responses, configuration=config)
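
With the debug print removed and the flattened dump in place, the endpoint consumes a GenerationInput payload directly. A hedged sketch of calling it over HTTP; the field names follow the diff, while the host, port, and payload values are assumptions:

    import requests

    # Assumes an OpenLLM service from this commit era running locally on
    # BentoML's default port.
    payload = {
        "prompt": "What is the capital of France?",
        "llm_config": {"max_new_tokens": 64, "temperature": 0.7},
    }
    resp = requests.post("http://localhost:3000/v1/generate", json=payload)
    resp.raise_for_status()
    print(resp.json()["responses"])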
@@ -81,7 +81,6 @@ class DollyV2(openllm.LLM):
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            do_sample=True,
             **kwargs,
         ).model_dump(flatten=True)
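
model_dump(flatten=True) is OpenLLM's own extension, not a standard pydantic argument. As used at these call sites, the evident contract is that nested generation options are hoisted into a single flat, kwargs-style dict. A purely illustrative sketch of that contract, with hypothetical field names:

    # Hypothetical illustration of the flatten contract these call sites
    # rely on; not OpenLLM's actual implementation.
    def flatten_dump(nested: dict) -> dict:
        flat = {}
        for key, value in nested.items():
            if isinstance(value, dict):
                flat.update(value)  # hoist nested options to the top level
            else:
                flat[key] = value
        return flat

    print(flatten_dump({"generation_config": {"temperature": 0.9, "top_k": 50}}))
    # {'temperature': 0.9, 'top_k': 50}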
@@ -65,7 +65,6 @@ class FlanT5(openllm.LLM):
         self,
         prompt: str,
         max_new_tokens: int | None = None,
-        do_sample: bool = True,
         temperature: float | None = None,
         top_k: float | None = None,
         top_p: float | None = None,
@@ -75,7 +74,7 @@ class FlanT5(openllm.LLM):
         input_ids = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt").input_ids).to(self.device)
         result_tensor = self.model.generate(
             input_ids,
-            do_sample=do_sample,
+            do_sample=True,
             generation_config=self.config.with_options(
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
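
The same two-part change shown here for FlanT5 repeats below for FlaxFlanT5, TFFlanT5, and StarCoder: do_sample leaves the generate() signature and is hardcoded to True at the call site, alongside the generation config. A minimal standalone sketch of that Hugging Face transformers pattern, using google/flan-t5-small purely as an example model:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

    input_ids = tokenizer("Translate to German: hello", return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        do_sample=True,  # hardcoded, mirroring the diff: sampling is always on
        generation_config=GenerationConfig(max_new_tokens=32, temperature=0.9),
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))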
@@ -57,7 +57,6 @@ class FlaxFlanT5(openllm.LLM):
         self,
         prompt: str,
         max_new_tokens: int | None = None,
-        do_sample: bool = True,
         temperature: float | None = None,
         top_k: float | None = None,
         top_p: float | None = None,
@@ -67,7 +66,7 @@ class FlaxFlanT5(openllm.LLM):
         input_ids = self.tokenizer(prompt, return_tensors="np")["input_ids"]
         result_tensor = self.model.generate(
             input_ids,
-            do_sample=do_sample,
+            do_sample=True,
             generation_config=self.config.with_options(
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
@@ -57,7 +57,6 @@ class TFFlanT5(openllm.LLM):
         self,
         prompt: str,
         max_new_tokens: int | None = None,
-        do_sample: bool = True,
         temperature: float | None = None,
         top_k: float | None = None,
         top_p: float | None = None,
@@ -67,7 +66,7 @@ class TFFlanT5(openllm.LLM):
         input_ids = self.tokenizer(prompt, return_tensors="tf").input_ids
         outputs = self.model.generate(
             input_ids,
-            do_sample=do_sample,
+            do_sample=True,
             generation_config=self.config.with_options(
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
@@ -128,7 +128,6 @@ class StarCoder(openllm.LLM):
     def generate(
         self,
         prompt: str,
-        do_sample: bool = True,
         temperature: float | None = None,
         top_p: float | None = None,
         max_new_tokens: int | None = None,
@@ -148,7 +147,7 @@ class StarCoder(openllm.LLM):
         inputs = t.cast("torch.Tensor", self.tokenizer.encode(prompt, return_tensors="pt")).to(self.device)
         result_tensor = self.model.generate(
             inputs,
-            do_sample=do_sample,
+            do_sample=True,
             generation_config=self.config.with_options(
                 top_p=top_p,
                 temperature=temperature,