mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-18 22:55:08 -05:00
chore(style): enable yapf to match with style guidelines
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -5,18 +5,23 @@ from openllm_core.config.configuration_starcoder import EOD, FIM_MIDDLE, FIM_PAD
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch, transformers
|
||||
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
try: return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally: torch.cuda.empty_cache()
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
@@ -26,6 +31,7 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
# TODO: We will probably want to return the tokenizer here so that we can manually process this
|
||||
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
|
||||
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", openllm.StoppingCriteriaList([]))
|
||||
|
||||
Reference in New Issue
Block a user