diff --git a/src/openllm/cli.py b/src/openllm/cli.py
index 523da3d4..76018c74 100644
--- a/src/openllm/cli.py
+++ b/src/openllm/cli.py
@@ -425,7 +425,7 @@ def start_model_command(
         click.secho(
             f"Make sure that you have the following dependencies available: {llm.requirements}\n", fg="yellow"
         )
-        click.secho(f"Starting LLM Server for '{model_name}'\n", fg="blue")
+        click.secho(f"\nStarting LLM Server for '{model_name}'\n", fg="blue")
         server_cls = getattr(bentoml, "HTTPServer" if not _serve_grpc else "GrpcServer")
         server: bentoml.server.Server = server_cls("_service.py:svc", **server_attrs)
         server.timeout = 90
diff --git a/src/openllm/models/stablelm/configuration_stablelm.py b/src/openllm/models/stablelm/configuration_stablelm.py
index 8b51ee7c..c0fff121 100644
--- a/src/openllm/models/stablelm/configuration_stablelm.py
+++ b/src/openllm/models/stablelm/configuration_stablelm.py
@@ -32,7 +32,9 @@ class StableLMConfig(openllm.LLMConfig, name_type="lowercase"):
 
     class GenerationConfig:
         temperature: float = 0.9
-        max_new_tokens: int = 64
+        max_new_tokens: int = 128
+        top_k: int = 0
+        top_p: float = 0.9
 
 
 START_STABLELM_COMMAND_DOCSTRING = """\
@@ -58,8 +60,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM will refuse to participate in anything that could harm a human.
 """  # noqa
 
-DEFAULT_PROMPT_TEMPLATE = """{system_prompt}
-<|USER|>{instruction}<|ASSISTANT|>
-""".format(
+DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format(
     system_prompt=SYSTEM_PROMPT, instruction="{instruction}"
 )
diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py
index 12d1efa9..4a8fac8f 100644
--- a/src/openllm/models/stablelm/modeling_stablelm.py
+++ b/src/openllm/models/stablelm/modeling_stablelm.py
@@ -16,10 +16,19 @@ from __future__ import annotations
 import logging
 import typing as t
 
+from transformers import StoppingCriteria, StoppingCriteriaList
+
 import openllm
 
 from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE
 
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = set([50278, 50279, 50277, 1, 0])
+        return input_ids[0][-1] in stop_ids
+
+
 if t.TYPE_CHECKING:
     import torch
 else:
@@ -43,14 +52,22 @@ class StableLM(openllm.LLM):
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+    import_kwargs = {
+        "torch_dtype": torch.float16,
+        "load_in_8bit": False,
+        "device_map": "auto",
+    }
+
     def preprocess_parameters(
         self,
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
+        top_k: int | None = None,
+        top_p: float | None = None,
         **attrs: t.Any,
     ) -> tuple[str, dict[str, t.Any]]:
-        if "tuned" in self.default_model:
+        if "tuned" in self._pretrained:
             prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)
         else:
             prompt_text = prompt
@@ -58,11 +75,13 @@ class StableLM(openllm.LLM):
         return prompt_text, self.config.model_construct_env(
             temperature=temperature,
             max_new_tokens=max_new_tokens,
+            top_k=top_k,
+            top_p=top_p,
             **attrs,
         ).model_dump(flatten=True)
 
-    def postprocess_parameters(self, prompt: str, generation_result: str, **_: t.Any) -> str:
-        return generation_result
+    def postprocess_parameters(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
+        return generation_result[0]
 
     @torch.inference_mode()
     def generate(
@@ -70,20 +89,30 @@ class StableLM(openllm.LLM):
         prompt: str,
         temperature: float | None = None,
         max_new_tokens: int | None = None,
+        top_k: int | None = None,
+        top_p: float | None = None,
         **attrs: t.Any,
     ) -> list[str]:
+        if not self.model.is_loaded_in_8bit:
+            self.model.half()
         if torch.cuda.is_available():
-            self.model.half().cuda()
+            self.model.cuda()
+
+        generation_kwargs = {
+            "do_sample": True,
+            "generation_config": self.config.model_construct_env(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_k=top_k,
+                top_p=top_p,
+                **attrs,
+            ).to_generation_config(),
+            "pad_token_id": self.tokenizer.eos_token_id,
+        }
+        if "tuned" in self._pretrained:
+            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnTokens()])
 
         inputs = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt")).to(self.device)
         with torch.device(self.device):
-            result_tensor = self.model.generate(
-                **inputs,
-                do_sample=True,
-                generation_config=self.config.model_construct_env(
-                    temperature=temperature,
-                    max_new_tokens=max_new_tokens,
-                    **attrs,
-                ).to_generation_config(),
-            )
-            return self.tokenizer.decode(result_tensor[0], skip_special_tokens=True)
+            tokens = self.model.generate(**inputs, **generation_kwargs)
+            return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)]