refactor: packages (#249)

This commit is contained in:
Aaron Pham
2023-08-22 08:55:46 -04:00
committed by GitHub
parent a964e659c1
commit 3ffb25a872
148 changed files with 2899 additions and 1937 deletions

View File

@@ -1,34 +1,24 @@
from __future__ import annotations
import logging, typing as t, bentoml, openllm
from openllm.utils import generate_labels
from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
logger = logging.getLogger(__name__)
from openllm_core.config.configuration_starcoder import EOD, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
if t.TYPE_CHECKING: import transformers
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
__openllm_internal__ = True
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
import torch, transformers
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
finally: torch.cuda.empty_cache()
def sanitize_parameters(self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
if fim_mode:
try: prefix, suffix = prompt.split(FIM_INDICATOR)
except Exception as err: raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
else: prompt_text = prompt
# XXX: This value for pad_token_id is currently a hack, need more investigate why the
# default starcoder doesn't include the same value as santacoder EOD
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
with torch.inference_mode():
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
# NOTE: support fine-tuning starcoder