From afddaed08c69846fb40542e9050acdea57dfed07 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Sat, 10 Jun 2023 02:14:13 -0400 Subject: [PATCH] fix(perf): respect per request information remove use_default_prompt_template options add pretrained to list of start help docstring fix flax generation config improve flax and tensorflow implementation Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- pyproject.toml | 4 +- src/openllm/_configuration.py | 93 +++++++++++-------- src/openllm/_llm.py | 4 - src/openllm/_service.py | 10 +- src/openllm/cli.py | 11 ++- .../models/chatglm/modeling_chatglm.py | 1 - .../models/dolly_v2/configuration_dolly_v2.py | 1 - .../models/dolly_v2/modeling_dolly_v2.py | 55 ++++++----- .../models/falcon/configuration_falcon.py | 2 - src/openllm/models/falcon/modeling_falcon.py | 15 +-- .../models/flan_t5/modeling_flan_t5.py | 15 +-- .../models/flan_t5/modeling_flax_flan_t5.py | 24 +++-- .../models/stablelm/configuration_stablelm.py | 4 +- .../models/stablelm/modeling_stablelm.py | 5 +- .../models/starcoder/modeling_starcoder.py | 6 -- src/openllm/utils/__init__.py | 1 + src/openllm_client/runtimes/base.py | 6 +- 17 files changed, 146 insertions(+), 111 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cecdb43d..d86d78d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", @@ -71,7 +69,7 @@ all = ["openllm[fine-tune]", "openllm[flan-t5]", "openllm[chatglm]", "openllm[st chatglm = ["cpm_kernels", "sentencepiece"] falcon = ["einops", "xformers", "safetensors"] fine-tune = ["peft", "bitsandbytes", "datasets", "accelerate"] -flan-t5 = ["flax", "jax", "jaxlib", "tensorflow"] +flan-t5 = ["flax", "jax", "jaxlib", "tensorflow", "keras"] starcoder = ["bitsandbytes"] [project.urls] diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index ca8fea7d..0d061426 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -48,6 +48,7 @@ import inflection import orjson from cattr.gen import make_dict_unstructure_fn, override from click_option_group import optgroup +from deepmerge.merger import Merger import openllm @@ -83,6 +84,15 @@ __all__ = ["LLMConfig"] logger = logging.getLogger(__name__) +config_merger = Merger( + # merge dicts + type_strategies=[(DictStrAny, "merge")], + # override all other types + fallback_strategies=["override"], + # override conflicting types + type_conflict_strategies=["override"], +) + @t.overload def attrs_to_options( @@ -593,13 +603,9 @@ def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConf return generated_cls -USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING = """\ -Whether a model should use their default prompt template setup. This is useful if -users wants to do some prompt engineering. Default to True. 
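# The 'config_merger' Merger configured above drives how per-request values are folded
# into '__openllm_extras__' later in this patch. A minimal, standalone sketch of that
# behaviour (dict values merge key by key, everything else is overridden); only the
# 'deepmerge' dependency is assumed, and the sample values are illustrative:
from deepmerge.merger import Merger

merger = Merger(
    type_strategies=[(dict, "merge")],       # dicts: merge recursively
    fallback_strategies=["override"],        # other types: take the incoming value
    type_conflict_strategies=["override"],   # conflicting types: take the incoming value
)

base = {"generation_config": {"temperature": 0.9, "top_k": 50}, "timeout": 3600}
incoming = {"generation_config": {"temperature": 0.2}, "timeout": 7200}
merger.merge(base, incoming)  # mutates and returns 'base'
assert base == {"generation_config": {"temperature": 0.2, "top_k": 50}, "timeout": 7200}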
-""" - -# NOTE: This DEFAULT_KEYMAP is a way to dynamically generate attr.field -DEFAULT_LLMCONFIG_ATTRS = (("use_default_prompt_template", True, USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING, bool),) +# NOTE: This DEFAULT_LLMCONFIG_ATTRS is a way to dynamically generate attr.field +# and will be saved for future use in LLMConfig if we have some shared config. +DEFAULT_LLMCONFIG_ATTRS: tuple[tuple[str, t.Any, str, type[t.Any]], ...] = () @attr.define @@ -652,6 +658,9 @@ class LLMConfig: __openllm_url__: str = Field(None, init=False) """The resolved url for this LLMConfig.""" + __openllm_accepted_keys__: set[str] = Field(None, init=False) + """The accepted keys for this LLMConfig.""" + __openllm_requirements__: list[str] | None = None """The default PyPI requirements needed to run this given LLM. By default, we will depend on bentoml, torch, transformers.""" @@ -674,10 +683,6 @@ class LLMConfig: """The result generated GenerationConfig class for this LLMConfig. This will be used to create the generation_config argument that can be used throughout the lifecycle.""" - # NOTE: The following can be shared accross all LLMConfig subclasses. - use_default_prompt_template: bool = Field(True, init=False) - use_default_prompt_template.__doc__ = USE_DEFAULT_PROMPT_TEMPLATE_DOCSTRING - def __init_subclass__( cls, *, @@ -757,11 +762,15 @@ class LLMConfig: cls.__openllm_attrs__ = tuple(a.name for a in own_attrs) # NOTE: Enable some default attributes that can be shared across all LLMConfig - base_attrs = [ - attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints) - for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS - if k not in cls.__openllm_attrs__ - ] + base_attrs + if len(DEFAULT_LLMCONFIG_ATTRS) > 0: + # NOTE: update the hints for default variables we dynamically added. + hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS}) + base_attrs = [ + attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints) + for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS + if k not in cls.__openllm_attrs__ + ] + base_attrs + attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs # Mandatory vs non-mandatory attr order only matters when they are part of @@ -817,10 +826,10 @@ class LLMConfig: hints.update(t.get_type_hints(cls.generation_class)) - # NOTE: update the hints for default variables we dynamically added. 
- hints.update({k: hints for k, _, _, hints in DEFAULT_LLMCONFIG_ATTRS}) cls.__openllm_hints__ = hints + cls.__openllm_accepted_keys__ = set(cls.__openllm_attrs__) | set(attr.fields_dict(cls.generation_class)) + @property def name_type(self) -> t.Literal["dasherize", "lowercase"]: return self.__openllm_name_type__ @@ -832,8 +841,10 @@ class LLMConfig: __openllm_extras__: dict[str, t.Any] | None = None, **attrs: t.Any, ): - to_exclude = list(attr.fields_dict(self.generation_class)) + list(self.__openllm_attrs__) - self.__openllm_extras__ = __openllm_extras__ or {k: v for k, v in attrs.items() if k not in to_exclude} + self.__openllm_extras__ = openllm.utils.first_not_none(__openllm_extras__, default={}) + config_merger.merge( + self.__openllm_extras__, {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__} + ) attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_extras__ and v is not None} @@ -844,9 +855,11 @@ class LLMConfig: attrs = {k: v for k, v in attrs.items() if k not in generation_config} - extras = set(attrs).difference(set(attr.fields_dict(self.__class__))) + self.__attrs_init__(**{k: v for k, v in attrs.items() if k in self.__openllm_attrs__}) - self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras}) + # The rest update to extras + attrs = {k: v for k, v in attrs.items() if k not in self.__openllm_attrs__} + config_merger.merge(self.__openllm_extras__, attrs) def __repr__(self) -> str: bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}" @@ -897,35 +910,35 @@ class LLMConfig: return orjson.dumps(self.model_dump(**kwargs)) @classmethod - def model_construct_env(cls, __llm_config__: LLMConfig | None = None, **attrs: t.Any) -> LLMConfig: + def model_construct_env(cls, **attrs: t.Any) -> LLMConfig: """A helpers that respect configuration values that sets from environment variables for any given configuration class. """ - # NOTE: filter out None values attrs = {k: v for k, v in attrs.items() if v is not None} - if "generation_config" in attrs: - # NOTE: We will need to flatten the attrs dict - generation_config = attrs.pop("generation_config", {}) - attrs.update(generation_config) - env = ModelEnv(cls.__openllm_model_name__) + model_config = ModelEnv(cls.__openllm_model_name__).model_config - env_json_string = os.environ.get(env.model_config, None) + env_json_string = os.environ.get(model_config, None) if env_json_string is not None: try: config_from_env = orjson.loads(env_json_string) except orjson.JSONDecodeError as e: - raise RuntimeError(f"Failed to parse '{env.model_config}' as valid JSON string.") from e - config_from_env.update(attrs) - return bentoml_cattr.structure(config_from_env, cls) + raise RuntimeError(f"Failed to parse '{model_config}' as valid JSON string.") from e + ncls = bentoml_cattr.structure(config_from_env, cls) + else: + ncls = cls() - if __llm_config__ is not None: - # NOTE: We only hit this branch on server-side, to ensure per-request configuration - # is respected. 
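# 'model_construct_env' above seeds the config from the model's JSON environment
# variable and then lets per-request attributes win. A rough sketch of that ordering;
# the variable name 'OPENLLM_FLAN_T5_CONFIG' is illustrative, not taken from ModelEnv:
import os
import orjson

os.environ["OPENLLM_FLAN_T5_CONFIG"] = orjson.dumps({"temperature": 0.75, "top_k": 40}).decode()

env_json_string = os.environ.get("OPENLLM_FLAN_T5_CONFIG", None)
config_from_env = orjson.loads(env_json_string) if env_json_string is not None else {}
per_request = {"temperature": 0.2}    # values sent with a single request
config_from_env.update(per_request)   # request-level values take precedence
assert config_from_env == {"temperature": 0.2, "top_k": 40}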
- attrs.update(__llm_config__.model_dump(flatten=True)) + if "generation_config" in attrs: + generation_config = attrs.pop("generation_config") + if not LazyType(DictStrAny).isinstance(generation_config): + raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + else: + generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(ncls.generation_class)} - return bentoml_cattr.structure(attrs, cls) + attrs = {k: v for k, v in attrs.items() if k not in generation_config} + ncls.generation_config = attr.evolve(ncls.generation_config, **generation_config) + return attr.evolve(ncls, **attrs) def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, dict[str, t.Any]]: """Parse given click attributes into a LLMConfig and return the remaining click attributes.""" @@ -1013,12 +1026,14 @@ def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMCon raise RuntimeError(f"Expected a dictionary, but got {type(data)}") cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_attrs__} + generation_cls_fields = attr.fields_dict(cls.generation_class) if "generation_config" in data: generation_config = data.pop("generation_config") if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields}) else: - generation_config = {k: v for k, v in data.items() if k in attr.fields_dict(cls.generation_class)} + generation_config = {k: v for k, v in data.items() if k in generation_cls_fields} not_extras = list(cls_attrs) + list(generation_config) # The rest should be passed to extras data = {k: v for k, v in data.items() if k not in not_extras} diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index ce6eecb6..050d028d 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -230,11 +230,7 @@ class LLMInterface(ABC): It takes a prompt that is given by the user, attrs that can be parsed with the prompt. - NOTE: the attrs should also handle the following default attributes from all LLMConfig: - - use_default_prompt_template - Returns a tuple of three items: - - The processed prompt text depending on `use_default_prompt_template` - The attributes dictionary that can be passed into LLMConfig to generate a GenerationConfig - The attributes dictionary that will be passed into `self.postprocess_generate`. 
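# Illustrative only: a minimal 'sanitize_parameters' that honours the three-tuple
# contract described above. The parameter names are examples, not a required signature:
def sanitize_parameters_example(prompt, max_new_tokens=None, temperature=None, **attrs):
    generate_kwargs = {"max_new_tokens": max_new_tokens, "temperature": temperature, **attrs}
    postprocess_kwargs = {}  # forwarded to 'postprocess_generate'
    return prompt, generate_kwargs, postprocess_kwargs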
""" diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 3e46e550..2906752a 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -39,14 +39,14 @@ svc = bentoml.Service(name=f"llm-{llm_config.__openllm_start_name__}-service", r @svc.api( - input=bentoml.io.JSON.from_sample(sample={"prompt": "", "llm_config": {}}), - output=bentoml.io.JSON.from_sample(sample={"responses": [], "configuration": {}}), + input=bentoml.io.JSON.from_sample(sample={"prompt": "", "llm_config": llm_config.model_dump()}), + output=bentoml.io.JSON.from_sample(sample={"responses": [], "configuration": llm_config.model_dump()}), route="/v1/generate", ) async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: - qa = openllm.GenerationInput.for_model(model)(**input_dict) - config = llm_config.model_construct_env(__llm_config__=qa.llm_config).model_dump() - responses = await runner.generate.async_run(qa.prompt, **config) + qa_inputs = openllm.GenerationInput.for_model(model)(**input_dict) + config = qa_inputs.llm_config.model_dump() + responses = await runner.generate.async_run(qa_inputs.prompt, **config) return openllm.GenerationOutput(responses=responses, configuration=config) diff --git a/src/openllm/cli.py b/src/openllm/cli.py index c2bd6b27..eb2db379 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -314,15 +314,20 @@ def start_model_command( configure_logging() ModelEnv = openllm.utils.ModelEnv(model_name) + llm_config = openllm.AutoConfig.for_model(model_name) + + docstring = f"""\ +{ModelEnv.start_docstring} +\b +The available pretrained models to use with '{model_name}' are: {openllm.AutoLLM.for_model(model_name).pretrained} +""" command_attrs: dict[str, t.Any] = { "name": ModelEnv.model_name, "context_settings": _context_settings or {}, "short_help": f"Start a LLMServer for '{model_name}' ('--help' for more details)", - "help": ModelEnv.start_docstring, + "help": docstring, } - llm_config = openllm.AutoConfig.for_model(model_name) - aliases: list[str] = [] if llm_config.name_type == "dasherize": aliases.append(llm_config.__openllm_start_name__) diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py index c519e24c..86a605cb 100644 --- a/src/openllm/models/chatglm/modeling_chatglm.py +++ b/src/openllm/models/chatglm/modeling_chatglm.py @@ -110,7 +110,6 @@ class ChatGLM(openllm.LLM): "num_beams": num_beams, "top_p": top_p, "temperature": temperature, - "use_default_prompt_template": use_default_prompt_template, **attrs, } diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py index 38aa9881..42a4b0bb 100644 --- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py +++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py @@ -42,7 +42,6 @@ class DollyV2Config( return_full_text: bool = openllm.LLMConfig.Field( False, description="Whether to return the full prompt to the users." 
) - use_default_prompt_template: bool = False class GenerationConfig: temperature: float = 0.9 diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index e6f3ffb8..722a41dc 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -33,6 +33,9 @@ logger = logging.getLogger(__name__) class DollyV2(openllm.LLM): + if t.TYPE_CHECKING: + config: openllm.DollyV2Config + __openllm_internal__ = True default_model = "databricks/dolly-v2-3b" @@ -58,12 +61,20 @@ class DollyV2(openllm.LLM): torch_dtype=torch_dtype, device_map=device_map, ) - return bentoml.transformers.save_model( - tag, - pipeline, - custom_objects={"tokenizer": tokenizer}, - external_modules=[importlib.import_module(pipeline.__module__)], - ) + try: + return bentoml.transformers.save_model( + tag, + pipeline, + custom_objects={"tokenizer": tokenizer}, + external_modules=[importlib.import_module(pipeline.__module__)], + ) + finally: + import gc + + gc.collect() + + if openllm.utils.is_torch_available() and torch.cuda.is_available(): + torch.cuda.empty_cache() def sanitize_parameters( self, @@ -72,39 +83,37 @@ class DollyV2(openllm.LLM): temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, - use_default_prompt_template: bool = False, **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: - if use_default_prompt_template: - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt) - else: - prompt_text = prompt - # NOTE: The rest of attrs should be kwargs for GenerationConfig generate_kwargs = { "max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, - "use_default_prompt_template": use_default_prompt_template, **attrs, } - return prompt_text, generate_kwargs, {} + return prompt, generate_kwargs, {} - def postprocess_generate(self, prompt: str, generation_result: str, **_: t.Any) -> str: - return generation_result + def postprocess_generate( + self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any + ) -> str: + return generation_result[0]["generated_text"] @torch.inference_mode() - def generate(self, prompt: str, **attrs: t.Any) -> str: + def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]: self.model.tokenizer = self.tokenizer - llm_config: openllm.DollyV2Config = self.config.model_construct_env(**attrs) - decoded = self.model(prompt, generation_config=llm_config.to_generation_config()) + llm_config = self.config.model_construct_env(**attrs) + decoded: list[dict[t.Literal["generated_text"], str]] = self.model( + prompt, generation_config=llm_config.to_generation_config() + ) - # If the full text is requested, then append the decoded text to the original instruction. - # This technically isn't the full text, as we format the instruction in the prompt the model has been - # trained on, but to the client it will appear to be the full text. 
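# Pure-Python sketch of the 'return_full_text' branch below: each pipeline item of the
# form {"generated_text": ...} gets the formatted instruction prepended. The template
# literal here is a placeholder, not Dolly-v2's real DEFAULT_PROMPT_TEMPLATE:
_TEMPLATE = "### Instruction:\n{instruction}\n### Response:\n"
_prompt = "Explain gradient checkpointing."
_decoded = [{"generated_text": "It trades compute for memory by re-running forward passes."}]
_full = [
    {key: f"{_TEMPLATE.format(instruction=_prompt)}\n{generated}"}
    for item in _decoded
    for key, generated in item.items()
]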
if llm_config.return_full_text: - decoded = f"{DEFAULT_PROMPT_TEMPLATE.format(prompt)}\n{decoded}" + return [ + {k: f"{DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt)}\n{generated}"} + for i in decoded + for k, generated in i.items() + ] return decoded diff --git a/src/openllm/models/falcon/configuration_falcon.py b/src/openllm/models/falcon/configuration_falcon.py index b7ae9b8a..bdd2e453 100644 --- a/src/openllm/models/falcon/configuration_falcon.py +++ b/src/openllm/models/falcon/configuration_falcon.py @@ -32,8 +32,6 @@ class FalconConfig( Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information. """ - use_default_prompt_template: bool = False - class GenerationConfig: max_new_tokens: int = 200 top_k: int = 10 diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index e75b6dac..25e09898 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -88,17 +88,20 @@ class Falcon(openllm.LLM): **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: if use_default_prompt_template: - prompt_variables = { - k: v - for k, v in attrs.items() - if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) - } + template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} if "instruction" in prompt_variables: raise RuntimeError( "'instruction' should be passed as the first argument instead of " "kwargs when 'use_default_prompt_template=True'" ) - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + try: + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " + "Use 'use_default_prompt_template=False' to disable the default prompt template." + ) else: prompt_text = prompt diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index 23083320..a87b01ae 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -53,17 +53,20 @@ class FlanT5(openllm.LLM): **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: if use_default_prompt_template: - prompt_variables = { - k: v - for k, v in attrs.items() - if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) - } + template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} if "instruction" in prompt_variables: raise RuntimeError( "'instruction' should be passed as the first argument " "instead of kwargs when 'use_default_prompt_template=True'" ) - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + try: + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " + "Use 'use_default_prompt_template=False' to disable the default prompt template." 
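# The try/except above guards against templates whose variables the caller never
# supplied: 'str.format' raises KeyError for a missing named field. A standalone
# reproduction with an illustrative template:
template = "{system_prompt}\nQ: {instruction}\nA:"
try:
    template.format(instruction="What is 2 + 2?")  # 'system_prompt' is missing
except KeyError as e:
    print(f"Missing variable {e.args[0]!r} in the prompt template")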
+ ) else: prompt_text = prompt diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index 7df58040..f7827c7b 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -42,30 +42,38 @@ class FlaxFlanT5(openllm.LLM): top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, + decoder_start_token_id: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: if use_default_prompt_template: - prompt_variables = { - k: v - for k, v in attrs.items() - if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) - } + template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} if "instruction" in prompt_variables: raise RuntimeError( "'instruction' should be passed as the first argument " "instead of kwargs when 'use_default_prompt_template=True'" ) - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + try: + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " + "Use 'use_default_prompt_template=False' to disable the default prompt template." + ) else: prompt_text = prompt + if decoder_start_token_id is None: + decoder_start_token_id = 0 + generation_config = { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, + "decoder_start_token_id": decoder_start_token_id, } return prompt_text, generation_config, {} @@ -73,11 +81,15 @@ class FlaxFlanT5(openllm.LLM): return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: + # XXX: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main + # as it is required for encoder-decoder generation. + decoder_start_token_id = attrs.pop("decoder_start_token_id", 0) input_ids = self.tokenizer(prompt, return_tensors="np")["input_ids"] result_tensor = self.model.generate( input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), + decoder_start_token_id=decoder_start_token_id, ) return self.tokenizer.batch_decode( result_tensor.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True diff --git a/src/openllm/models/stablelm/configuration_stablelm.py b/src/openllm/models/stablelm/configuration_stablelm.py index 2b505aef..72f52241 100644 --- a/src/openllm/models/stablelm/configuration_stablelm.py +++ b/src/openllm/models/stablelm/configuration_stablelm.py @@ -65,6 +65,4 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version) - StableLM will refuse to participate in anything that could harm a human. 
""" # noqa -DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""".format( - system_prompt=SYSTEM_PROMPT, instruction="{instruction}" -) +DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""" diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index 3bdc38e1..62944f17 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -21,7 +21,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList import openllm from ..._prompt import default_formatter -from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE +from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT class StopOnTokens(StoppingCriteria): @@ -81,7 +81,8 @@ class StableLM(openllm.LLM): "'instruction' should be passed as the first argument " "instead of kwargs when 'use_default_prompt_template=True'" ) - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + system_prompt = prompt_variables.pop("system_prompt", SYSTEM_PROMPT) + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, system_prompt=system_prompt) else: prompt_text = prompt diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index a605a94a..70c4214d 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -20,8 +20,6 @@ import bentoml import openllm -from .configuration_starcoder import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch import transformers @@ -103,7 +101,6 @@ class StarCoder(openllm.LLM): top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, - use_default_prompt_template: bool = True, **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: fim_mode = FIM_INDICATOR in prompt @@ -129,9 +126,6 @@ class StarCoder(openllm.LLM): **attrs, } - if use_default_prompt_template: - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt_text) - return prompt_text, generation_config, {} def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index 50ccea13..3d883698 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -34,6 +34,7 @@ from bentoml._internal.types import LazyType as LazyType from bentoml._internal.utils import LazyLoader as LazyLoader from bentoml._internal.utils import bentoml_cattr as bentoml_cattr from bentoml._internal.utils import copy_file_to_fs_folder as copy_file_to_fs_folder +from bentoml._internal.utils import first_not_none as first_not_none from bentoml._internal.utils import pkg as pkg from bentoml._internal.utils import reserve_free_port as reserve_free_port from bentoml._internal.utils import resolve_user_filepath as resolve_user_filepath diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index 289f10f7..4874de3d 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -161,7 +161,11 @@ class BaseAsyncClient(ClientMixin): ... async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str: - return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs) + # NOTE: We set use_default_prompt_template to False for now. 
+ use_default_prompt_template = attrs.pop("use_default_prompt_template", False) + return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare( + prompt, use_default_prompt_template=use_default_prompt_template, **attrs + ) inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) res = await self.acall("generate", inputs) r = openllm.GenerationOutput(**res)
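# End-to-end, the client-side 'query' above forwards generation kwargs into
# 'model_construct_env' before hitting the server. A hedged usage sketch; the client
# class name and server address are assumptions, not defined in this patch:
import asyncio
import openllm_client

async def main():
    client = openllm_client.AsyncHTTPClient("http://127.0.0.1:3000")  # assumed constructor
    result = await client.query("Explain KV caching in transformers.", max_new_tokens=64)
    print(result)

asyncio.run(main())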