# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import typing as t

import bentoml
import openllm

from ..._generation import StopSequenceCriteria

if t.TYPE_CHECKING:
    import torch
    import transformers
else:
    torch = openllm.utils.LazyLoader("torch", globals(), "torch")
    transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")

logger = logging.getLogger(__name__)
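
# Special tokens StarCoder uses for fill-in-the-middle (FIM) prompting; EOD doubles as the
# padding token, and FIM_INDICATOR marks where the infill should go inside a user prompt.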
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
FIM_INDICATOR = "<FILL_HERE>"


class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
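    """OpenLLM implementation of StarCoder (GPTBigCode architecture).

    ``import_model`` saves the HuggingFace checkpoint and its tokenizer (with the FIM special
    tokens registered) into the BentoML model store, ``sanitize_parameters`` builds either a
    plain or a fill-in-the-middle prompt together with the generation config, and
    ``generate``/``generate_one`` run inference through ``transformers``.
    """
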
    __openllm_internal__ = True
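
    # Post-init hook from openllm.LLM: resolve the torch device used for inference.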
    def llm_post_init(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
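
    # Default keyword arguments used when importing the model and tokenizer from HuggingFace.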
    @property
    def import_kwargs(self):
        model_kwds = {
            "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
        }
        tokenizer_kwds = {"padding_side": "left"}
        return model_kwds, tokenizer_kwds
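
    # Download the checkpoint, register the FIM special tokens on the tokenizer, and persist
    # both to the BentoML model store.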
    def import_model(
        self,
        model_id: str,
        tag: bentoml.Tag,
        *model_args: t.Any,
        tokenizer_kwds: dict[str, t.Any],
        **attrs: t.Any,
    ) -> bentoml.Model:
        torch_dtype = attrs.pop("torch_dtype", torch.float16)
        device_map = attrs.pop("device_map", "auto")

        tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, **tokenizer_kwds)
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
                "pad_token": EOD,
            }
        )

        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs
        )
        try:
            return bentoml.transformers.save_model(tag, model, custom_objects={"tokenizer": tokenizer})
        finally:
            # NOTE: We need to free the cache after saving here so that we can load it back later on.
            torch.cuda.empty_cache()
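
    # Build the final prompt text and generation config. When the prompt contains FIM_INDICATOR,
    # it is rewritten into StarCoder's fill-in-the-middle format, e.g. (sketch):
    #
    #   "def add(a, b):\n    <FILL_HERE>\n"
    #   -> "<fim-prefix>def add(a, b):\n    <fim-suffix>\n<fim-middle>"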
    def sanitize_parameters(
        self,
        prompt: str,
        temperature: float | None = None,
        top_p: float | None = None,
        max_new_tokens: int | None = None,
        repetition_penalty: float | None = None,
        **attrs: t.Any,
    ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
        fim_mode = FIM_INDICATOR in prompt
        prefix, suffix = None, None
        if fim_mode:
            try:
                prefix, suffix = prompt.split(FIM_INDICATOR)
            except Exception as err:
                logger.error("Error while processing prompt with FIM mode:\n", exc_info=err)
                raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
            prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        else:
            prompt_text = prompt

        generation_config = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            # XXX: This value is currently a hack; we still need to investigate why the default
            # starcoder config doesn't use the same value as the santacoder EOD token.
            "pad_token_id": 49152,
            **attrs,
        }

        return prompt_text, generation_config, {}
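
    # The first decoded sequence is returned to the client as-is.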
    def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
        return generation_result[0]
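
    # Run generation for a single prompt; sampling parameters come from the model config merged
    # with any overrides passed through ``attrs``.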
    def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
        with torch.inference_mode():
            inputs = t.cast("torch.Tensor", self.tokenizer.encode(prompt, return_tensors="pt")).to(self.device)
            result_tensor = self.model.generate(
                inputs,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                # eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"),  # NOTE: this is for finetuning starcoder
                generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
            )
            # TODO: We will probably want to return the tokenizer here so that we can manually process this
            # return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
            return self.tokenizer.batch_decode(
                result_tensor[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
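
    # Generation with explicit stop sequences, mirroring the HuggingFace Inference API behaviour:
    # any trailing stop sequence is stripped from the decoded completion.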
    def generate_one(
        self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any
    ) -> list[dict[t.Literal["generated_text"], str]]:
        max_new_tokens = preprocess_generate_kwds.pop("max_new_tokens", 200)
        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
        stopping_criteria.append(StopSequenceCriteria(stop, self.tokenizer))
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria
        )

        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return [{"generated_text": result}]