fix(stablelm): Ensure passing EOS_TOKEN_ID for generation

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: aarnphm-ec2-dev
Date:   2023-05-28 14:43:00 +00:00
Parent: b4403c24b0
Commit: 8ca488d8fc

3 changed files with 48 additions and 19 deletions


@@ -425,7 +425,7 @@ def start_model_command(
         click.secho(
             f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow"
         )
-    click.secho(f"Starting LLM Server for '{model_name}'\n", fg="blue")
+    click.secho(f"\nStarting LLM Server for '{model_name}'\n", fg="blue")
     server_cls = getattr(bentoml, "HTTPServer" if not _serve_grpc else "GrpcServer")
     server: bentoml.server.Server = server_cls("_service.py:svc", **server_attrs)
     server.timeout = 90


@@ -32,7 +32,9 @@ class StableLMConfig(openllm.LLMConfig, name_type="lowercase"):
     class GenerationConfig:
         temperature: float = 0.9
-        max_new_tokens: int = 64
+        max_new_tokens: int = 128
+        top_k: int = 0
+        top_p: float = 0.9
 START_STABLELM_COMMAND_DOCSTRING = """\
@@ -58,8 +60,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM will refuse to participate in anything that could harm a human.
 """  # noqa
-DEFAULT_PROMPT_TEMPLATE = """{system_prompt}
-<|USER|>{instruction}<|ASSISTANT|>
-""".format(
+DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
     system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
 )
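For reference, the collapsed template yields the same prompt as before minus the newline between the system block and <|USER|> and the trailing newline after <|ASSISTANT|>. A standalone sketch of how it gets filled at request time (the SYSTEM_PROMPT value here is abbreviated; the real one is the multi-line block quoted above):

# Standalone sketch of how the collapsed DEFAULT_PROMPT_TEMPLATE is used.
# SYSTEM_PROMPT is abbreviated here; the real value is the block quoted above.
SYSTEM_PROMPT = "<|SYSTEM|># StableLM Tuned (Alpha version)\n...\n"
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
    system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
)

# Only {instruction} is left to fill per request; {system_prompt} was baked in above,
# and generated text is expected to start right after <|ASSISTANT|> with no extra newline.
print(DEFAULT_PROMPT_TEMPLATE.format(instruction="Write a haiku about GPUs."))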


@@ -16,10 +16,19 @@ from __future__ import annotations
 import logging
 import typing as t
+from transformers import StoppingCriteria, StoppingCriteriaList
 import openllm
 from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = set([50278, 50279, 50277, 1, 0])
+        return input_ids[0][-1] in stop_ids
 if t.TYPE_CHECKING:
     import torch
 else:
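As an illustration of what the hard-coded stop ids do, here is a standalone sketch in the spirit of the StopOnTokens class above. The class name, the explicit int() conversion of the last token id, and the non-stop sample ids are this sketch's own choices rather than part of the commit:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

# Same hard-coded StableLM-Tuned stop ids as in the diff above.
STABLELM_STOP_IDS = {50278, 50279, 50277, 1, 0}


class StopOnStableLMTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is one of the stop ids.
        return int(input_ids[0, -1]) in STABLELM_STOP_IDS


criterion = StopOnStableLMTokens()
scores = torch.zeros(1, 50304)  # logits are unused by this criterion; shape is illustrative
print(criterion(torch.tensor([[12, 345, 678]]), scores))    # False -> keep generating
print(criterion(torch.tensor([[12, 345, 50278]]), scores))  # True  -> stop here

# generate() consumes it wrapped in a StoppingCriteriaList, as the diff does further below.
criteria = StoppingCriteriaList([StopOnStableLMTokens()])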
@@ -43,14 +52,22 @@ class StableLM(openllm.LLM):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    import_kwargs = {
+        "torch_dtype": torch.float16,
+        "load_in_8bit": False,
+        "device_map": "auto",
+    }
     def preprocess_parameters(
         self,
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
+        top_k: int | None = None,
+        top_p: float | None = None,
         **attrs: t.Any,
     ) -> tuple[str, dict[str, t.Any]]:
-        if "tuned" in self.default_model:
+        if "tuned" in self._pretrained:
             prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)
         else:
             prompt_text = prompt
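The new import_kwargs map directly onto the loading flags that transformers' from_pretrained accepted at the time. A rough sketch of the equivalent direct call; the checkpoint id is illustrative, and device_map="auto" assumes accelerate is installed (bitsandbytes as well if load_in_8bit is flipped to True):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stabilityai/stablelm-tuned-alpha-7b"  # illustrative checkpoint id, not from the diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half-precision weights
    load_in_8bit=False,         # set True for bitsandbytes int8 loading instead
    device_map="auto",          # let accelerate place weights on the available devices
)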
@@ -58,11 +75,13 @@ class StableLM(openllm.LLM):
         return prompt_text, self.config.model_construct_env(
             temperature=temperature,
             max_new_tokens=max_new_tokens,
+            top_k=top_k,
+            top_p=top_p,
             **attrs,
         ).model_dump(flatten=True)
-    def postprocess_parameters(self, prompt: str, generation_result: str, **_: t.Any) -> str:
-        return generation_result
+    def postprocess_parameters(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
+        return generation_result[0]
     @torch.inference_mode()
     def generate(
@@ -70,20 +89,30 @@ class StableLM(openllm.LLM):
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
+        top_k: int | None = None,
+        top_p: float | None = None,
         **attrs: t.Any,
     ) -> list[str]:
+        if not self.model.is_loaded_in_8bit:
+            self.model.half()
         if torch.cuda.is_available():
-            self.model.half().cuda()
+            self.model.cuda()
+        generation_kwargs = {
+            "do_sample": True,
+            "generation_config": self.config.model_construct_env(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_k=top_k,
+                top_p=top_p,
+                **attrs,
+            ).to_generation_config(),
+            "pad_token_id": self.tokenizer.eos_token_id,
+        }
+        if "tuned" in self._pretrained:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnTokens()])
         inputs = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt")).to(self.device)
-        with torch.device(self.device):
-            result_tensor = self.model.generate(
-                **inputs,
-                do_sample=True,
-                generation_config=self.config.model_construct_env(
-                    temperature=temperature,
-                    max_new_tokens=max_new_tokens,
-                    **attrs,
-                ).to_generation_config(),
-            )
-        return self.tokenizer.decode(result_tensor[0], skip_special_tokens=True)
+        tokens = self.model.generate(**inputs, **generation_kwargs)
+        return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)]
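The pad_token_id entry is the point of the commit: when a tokenizer defines no pad token, transformers warns on every open-ended generate call and falls back to EOS, so the code now passes the EOS id explicitly. A minimal sketch of the resulting call pattern, using gpt2 purely because it is small and its tokenizer likewise lacks a pad token; plain do_sample/max_new_tokens kwargs stand in for the OpenLLM GenerationConfig used above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 stands in for a StableLM checkpoint; the kwargs follow the generation_kwargs dict above.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
generation_kwargs = {
    "do_sample": True,
    "max_new_tokens": 32,
    "pad_token_id": tokenizer.eos_token_id,  # the fix: pad with EOS for open-ended generation
}
with torch.inference_mode():
    tokens = model.generate(**inputs, **generation_kwargs)
print([tokenizer.decode(tokens[0], skip_special_tokens=True)])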