fix(stablelm): Ensure passing EOS_TOKEN_ID for generation
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
@@ -425,7 +425,7 @@ def start_model_command(
     click.secho(
         f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow"
     )
-    click.secho(f"Starting LLM Server for '{model_name}'\n", fg="blue")
+    click.secho(f"\nStarting LLM Server for '{model_name}'\n", fg="blue")
     server_cls = getattr(bentoml, "HTTPServer" if not _serve_grpc else "GrpcServer")
     server: bentoml.server.Server = server_cls("_service.py:svc", **server_attrs)
     server.timeout = 90

@@ -32,7 +32,9 @@ class StableLMConfig(openllm.LLMConfig, name_type="lowercase"):

     class GenerationConfig:
         temperature: float = 0.9
-        max_new_tokens: int = 64
+        max_new_tokens: int = 128
         top_k: int = 0
         top_p: float = 0.9


 START_STABLELM_COMMAND_DOCSTRING = """\
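For orientation only (not part of the diff): the fields on this nested GenerationConfig line up with the sampling options that transformers' own GenerationConfig accepts, so the new defaults roughly correspond to the following sketch.

from transformers import GenerationConfig

# Illustrative only: what the StableLM defaults above translate to on the transformers side.
generation_config = GenerationConfig(
    temperature=0.9,
    max_new_tokens=128,  # raised from 64 in this commit
    top_k=0,
    top_p=0.9,
    do_sample=True,
)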
@@ -58,8 +60,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM will refuse to participate in anything that could harm a human.
 """ # noqa

-DEFAULT_PROMPT_TEMPLATE = """{system_prompt}
-<|USER|>{instruction}<|ASSISTANT|>
-""".format(
+DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
     system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
 )

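Since the template is now collapsed onto one line, the rendered prompt no longer wraps the role tags in newlines. A small sketch of the two-step formatting (the system prompt is truncated here and the instruction is made up):

SYSTEM_PROMPT = "<|SYSTEM|># StableLM Tuned (Alpha version) ..."  # truncated for brevity
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
    system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
)

# Later, preprocess_parameters fills in the remaining placeholder:
print(DEFAULT_PROMPT_TEMPLATE.format(instruction="Tell me a joke."))
# -> <|SYSTEM|># StableLM Tuned (Alpha version) ...<|USER|>Tell me a joke.<|ASSISTANT|>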
@@ -16,10 +16,19 @@ from __future__ import annotations
 import logging
 import typing as t

+from transformers import StoppingCriteria, StoppingCriteriaList
+
 import openllm

 from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE

+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = set([50278, 50279, 50277, 1, 0])
+        return input_ids[0][-1] in stop_ids
+
+
 if t.TYPE_CHECKING:
     import torch
 else:
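To see what the new stopping criterion does in isolation, here is a self-contained sketch. It is a variant of the class added above (it casts the last token id to a plain int before the set lookup); the token ids and tensors are illustrative.

import torch
from transformers import StoppingCriteria


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop/role token ids used by the tuned StableLM checkpoints (see the hunk above).
        stop_ids = {50278, 50279, 50277, 1, 0}
        return int(input_ids[0, -1]) in stop_ids


criterion = StopOnTokens()
scores = torch.zeros(1, 50432)  # placeholder logits, shape (batch_size, vocab_size)
print(criterion(torch.tensor([[42, 50278]]), scores))  # True  -> generate() would stop here
print(criterion(torch.tensor([[42, 7]]), scores))      # False -> generation continues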
@@ -43,14 +52,22 @@ class StableLM(openllm.LLM):

+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    import_kwargs = {
+        "torch_dtype": torch.float16,
+        "load_in_8bit": False,
+        "device_map": "auto",
+    }
+
     def preprocess_parameters(
         self,
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
         top_k: int | None = None,
         top_p: float | None = None,
         **attrs: t.Any,
     ) -> tuple[str, dict[str, t.Any]]:
-        if "tuned" in self.default_model:
+        if "tuned" in self._pretrained:
             prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)
         else:
             prompt_text = prompt
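The new import_kwargs mirror standard Hugging Face loading options. A hedged sketch of how they would be consumed outside OpenLLM (the checkpoint name is only an example; device_map="auto" requires accelerate, and flipping load_in_8bit to True would require bitsandbytes):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-7b",  # illustrative checkpoint
    torch_dtype=torch.float16,
    load_in_8bit=False,
    device_map="auto",
)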
@@ -58,11 +75,13 @@ class StableLM(openllm.LLM):
         return prompt_text, self.config.model_construct_env(
             temperature=temperature,
             max_new_tokens=max_new_tokens,
             top_k=top_k,
             top_p=top_p,
             **attrs,
         ).model_dump(flatten=True)

-    def postprocess_parameters(self, prompt: str, generation_result: str, **_: t.Any) -> str:
-        return generation_result
+    def postprocess_parameters(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
+        return generation_result[0]

+    @torch.inference_mode()
     def generate(
@@ -70,20 +89,30 @@ class StableLM(openllm.LLM):
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
         top_k: int | None = None,
         top_p: float | None = None,
         **attrs: t.Any,
     ) -> list[str]:
-        if not self.model.is_loaded_in_8bit:
-            self.model.half()
         if torch.cuda.is_available():
-            self.model.half().cuda()
+            self.model.cuda()

+        generation_kwargs = {
+            "do_sample": True,
+            "generation_config": self.config.model_construct_env(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_k=top_k,
+                top_p=top_p,
+                **attrs,
+            ).to_generation_config(),
+            "pad_token_id": self.tokenizer.eos_token_id,
+        }
+        if "tuned" in self._pretrained:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnTokens()])
+
         inputs = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt")).to(self.device)
-        with torch.device(self.device):
-            result_tensor = self.model.generate(
-                **inputs,
-                do_sample=True,
-                generation_config=self.config.model_construct_env(
-                    temperature=temperature,
-                    max_new_tokens=max_new_tokens,
-                    **attrs,
-                ).to_generation_config(),
-            )
-            return self.tokenizer.decode(result_tensor[0], skip_special_tokens=True)
+        tokens = self.model.generate(**inputs, **generation_kwargs)
+        return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)]
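Stripped of the OpenLLM plumbing, the gist of the fix is to hand the tokenizer's EOS id to generate() as pad_token_id, so open-ended generation has an explicit pad token instead of generate() falling back and warning about a missing one. A minimal sketch with plain transformers (checkpoint and prompt are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stabilityai/stablelm-tuned-alpha-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("<|USER|>What is your favorite color?<|ASSISTANT|>", return_tensors="pt").to(model.device)
tokens = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.9,
    top_p=0.9,
    max_new_tokens=128,
    pad_token_id=tokenizer.eos_token_id,  # the EOS token id this commit ensures is passed
    # for the tuned checkpoints, also: stopping_criteria=StoppingCriteriaList([StopOnTokens()])
)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))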