Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-04-23 08:28:24 -04:00).
fix(dolly-v2): using pipeline for latest implementation
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -741,10 +741,15 @@ class LLMConfig:
|
||||
|
||||
base_attrs, base_attr_map = _collect_base_attrs(cls, {a.name for a in own_attrs})
|
||||
|
||||
# __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
|
||||
# that we construct ourself.
|
||||
cls.__openllm_attrs__ = tuple(a.name for a in own_attrs)
|
||||
|
||||
# NOTE: Enable some default attributes that can be shared across all LLMConfig
|
||||
base_attrs = [
|
||||
attr.Attribute.from_counting_attr(k, cls.Field(default, env=field_env_key(k), description=docs), hints)
|
||||
for k, default, docs, hints in DEFAULT_LLMCONFIG_ATTRS
|
||||
if k not in cls.__openllm_attrs__
|
||||
] + base_attrs
|
||||
attrs: list[attr.Attribute[t.Any]] = own_attrs + base_attrs
|
||||
|
||||
@@ -776,9 +781,6 @@ class LLMConfig:
|
||||
_has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
|
||||
_has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
|
||||
|
||||
# __openllm_attrs__ is a tracking tuple[attr.Attribute[t.Any]]
|
||||
# that we construct ourself.
|
||||
cls.__openllm_attrs__ = tuple(a.name for a in attrs)
|
||||
AttrsTuple = _make_attr_tuple_class(cls.__name__, cls.__openllm_attrs__)
|
||||
# NOTE: generate a __attrs_init__ for the subclass
|
||||
cls.__attrs_init__ = _add_method_dunders(
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
import openllm
|
||||
|
||||
|
||||
class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
|
||||
class DollyV2Config(openllm.LLMConfig, default_timeout=3600000, trust_remote_code=True):
|
||||
"""Databricks’ Dolly is an instruction-following large language model trained on the Databricks
|
||||
machine learning platform that is licensed for commercial use.
|
||||
|
||||
@@ -37,6 +37,7 @@ class DollyV2Config(openllm.LLMConfig, default_timeout=3600000):
|
||||
return_full_text: bool = openllm.LLMConfig.Field(
|
||||
False, description="Whether to return the full prompt to the users."
|
||||
)
|
||||
use_default_prompt_template: bool = False
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
@@ -58,16 +61,36 @@ class DollyV2(openllm.LLM):
|
||||
|
||||
# Default checkpoint to download when the user does not specify one.
default_model = "databricks/dolly-v2-3b"

load_in_mha = False  # NOTE: disable bettertransformer for dolly

# All released dolly-v2 checkpoint sizes.
variants = ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]

# Kwargs forwarded when importing the model. NOTE(review): the superseded variant of this
# dict misspelled the last key as "_tokenizer_padding_size"; "_tokenizer_padding_side" is
# the corrected spelling -- presumably consumed by the tokenizer loader to set
# left-padding for generation; confirm against the import machinery.
import_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16, "_tokenizer_padding_side": "left"}

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
||||
|
||||
def import_model(
    self, pretrained: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any
) -> bentoml.Model:
    """Download the dolly-v2 checkpoint as a transformers pipeline and save it to the BentoML store.

    Args:
        pretrained: HuggingFace model id (e.g. ``databricks/dolly-v2-3b``).
        tag: BentoML tag under which the model is saved.
        model_args: Extra positional args (unused by this implementation).
        tokenizer_kwds: Keyword arguments forwarded to ``AutoTokenizer.from_pretrained``.
        attrs: May carry ``trust_remote_code``, ``torch_dtype`` and ``device_map`` overrides.

    Returns:
        The saved ``bentoml.Model``.
    """
    # Pull the pipeline-construction overrides out of attrs, with dolly-appropriate defaults.
    pipeline_opts = {
        "trust_remote_code": attrs.pop("trust_remote_code", True),
        "torch_dtype": attrs.pop("torch_dtype", torch.bfloat16),
        "device_map": attrs.pop("device_map", "auto"),
    }

    dolly_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained, **tokenizer_kwds)
    dolly_pipeline = transformers.pipeline(model=pretrained, tokenizer=dolly_tokenizer, **pipeline_opts)

    # Bundle the module that defines the remote custom pipeline class so the saved
    # model can be deserialized later, and keep the tokenizer alongside it.
    return bentoml.transformers.save_model(
        tag,
        dolly_pipeline,
        custom_objects={"tokenizer": dolly_tokenizer},
        external_modules=[importlib.import_module(dolly_pipeline.__module__)],
    )
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -75,7 +98,7 @@ class DollyV2(openllm.LLM):
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if use_default_prompt_template:
|
||||
@@ -95,117 +118,19 @@ class DollyV2(openllm.LLM):
|
||||
|
||||
return prompt_text, generate_kwargs, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: str, **_: t.Any) -> str:
    """Return the generated text unchanged.

    ``generate`` already produces the final decoded string (including the
    ``return_full_text`` handling), so no further post-processing is required.

    Args:
        prompt: The original prompt (unused).
        generation_result: The decoded text produced by ``generate``.

    Returns:
        ``generation_result`` as-is.
    """
    return generation_result
|
||||
|
||||
@torch.inference_mode()
def generate(self, prompt: str, **attrs: t.Any) -> str:
    """Generate a response for ``prompt`` by delegating to the model's custom pipeline.

    This replaces the previous hand-rolled reimplementation of databricks'
    InstructionTextGenerationPipeline: the saved pipeline now performs tokenization,
    generation and response extraction itself.

    Args:
        prompt: The (already formatted) instruction text.
        attrs: Generation overrides merged into this model's config via
            ``model_construct_env``.

    Returns:
        The decoded response text.
    """
    # NOTE(review): the pipeline is saved without its tokenizer attached, so it is
    # (re)attached here before every call -- confirm this is required by the remote
    # pipeline implementation.
    self.model.tokenizer = self.tokenizer
    llm_config: openllm.DollyV2Config = self.config.model_construct_env(**attrs)
    decoded = self.model(prompt, do_sample=True, generation_config=llm_config.to_generation_config())

    # If the full text is requested, then append the decoded text to the original instruction.
    # This technically isn't the full text, as we format the instruction in the prompt the model
    # has been trained on, but to the client it will appear to be the full text.
    if llm_config.return_full_text:
        decoded = f"{DEFAULT_PROMPT_TEMPLATE.format(prompt)}\n{decoded}"

    return decoded
|
||||
|
||||
Reference in New Issue
Block a user