mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-22 22:39:47 -05:00
171 lines
5.9 KiB
Python
171 lines
5.9 KiB
Python
# Copyright 2023 BentoML Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import typing as t
|
|
|
|
import bentoml
|
|
|
|
import openllm
|
|
|
|
if t.TYPE_CHECKING:
|
|
import torch
|
|
import transformers
|
|
else:
|
|
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
|
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Special tokens for StarCoder's fill-in-the-middle (FIM) prompt layout.
# StarCoder.import_model registers these with the tokenizer, and the prompt
# builder emits them in the order PREFIX ... SUFFIX ... MIDDLE.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
# End-of-document token; import_model also installs it as the tokenizer's pad token.
EOD = "<|endoftext|>"
# Marker a caller places in the prompt, at most once, to request FIM mode.
FIM_INDICATOR = "<FILL_HERE>"
|
|
|
|
|
|
# XXX: This value is currently a hack, need more investigate why the
# default starcoder doesn't include the same value as santacoder EOD
_STARCODER_PAD_TOKEN_ID = 49152


def _build_fim_prompt(prompt: str) -> str:
    """Rewrite ``prompt`` into the fill-in-the-middle layout when it asks for it.

    A prompt without ``FIM_INDICATOR`` is returned unchanged. Otherwise it is
    split around the indicator and re-assembled as
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``.

    Raises:
        ValueError: if ``FIM_INDICATOR`` appears more than once in ``prompt``.
    """
    if FIM_INDICATOR not in prompt:
        return prompt
    try:
        # The indicator is known to be present, so the split yields at least two
        # parts; unpacking fails with ValueError only when there are more.
        prefix, suffix = prompt.split(FIM_INDICATOR)
    except ValueError as err:
        logger.error("Error while processing prompt with FIM mode:\n", exc_info=err)
        raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


class StarCoder(openllm.LLM):
    """OpenLLM wrapper around the BigCode StarCoder causal language models.

    Prompts containing ``FIM_INDICATOR`` are converted to the model's
    fill-in-the-middle token layout before generation.
    """

    __openllm_internal__ = True

    default_model = "bigcode/starcoder"

    # ``load_in_8bit`` (see ``import_kwargs``) relies on the ``bitsandbytes`` package.
    requirements = ["bitsandbytes"]

    variants = ["bigcode/starcoder", "bigcode/starcoderbase"]

    # NOTE(review): evaluated while the class body executes, so ``torch`` is
    # touched at import time even though it is wrapped in a LazyLoader - confirm
    # this is intended.
    device = torch.device("cuda")

    import_kwargs = {
        "_tokenizer_padding_side": "left",
        "device_map": "auto",
        "load_in_8bit": True,
        "torch_dtype": torch.float16,
    }

    def import_model(
        self,
        pretrained: str,
        tag: bentoml.Tag,
        *model_args: t.Any,
        tokenizer_kwds: dict[str, t.Any],
        **attrs: t.Any,
    ) -> bentoml.Model:
        """Load the pretrained tokenizer and model, then save them under ``tag``.

        Args:
            pretrained: model identifier or path forwarded to ``from_pretrained``.
            tag: tag the saved model is stored under.
            *model_args: accepted for interface compatibility; not used here.
            tokenizer_kwds: extra keyword arguments for the tokenizer loader.
            **attrs: extra keyword arguments for the model loader.
                ``trust_remote_code``, ``torch_dtype``, ``load_in_8bit`` and
                ``device_map`` are popped out and given defaults when absent.

        Returns:
            The model returned by ``bentoml.transformers.save_model``, with the
            tokenizer attached as a custom object.
        """
        # NOTE(review): remote code is trusted by default; pass
        # ``trust_remote_code=False`` when loading from an untrusted source.
        trust_remote_code = attrs.pop("trust_remote_code", True)

        torch_dtype = attrs.pop("torch_dtype", torch.float16)
        load_in_8bit = attrs.pop("load_in_8bit", True)
        device_map = attrs.pop("device_map", "auto")

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, **tokenizer_kwds
        )
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
                "pad_token": EOD,
            }
        )

        model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            torch_dtype=torch_dtype,
            load_in_8bit=load_in_8bit,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            **attrs,
        )

        try:
            return bentoml.transformers.save_model(str(tag), model, custom_objects={"tokenizer": tokenizer})
        finally:
            import gc

            # NOTE: We need to free the cache after saving here so that we can load it back later on.
            gc.collect()
            torch.cuda.empty_cache()

    def preprocess_parameters(
        self,
        prompt: str,
        temperature: float | None = None,
        top_p: float | None = None,
        max_new_tokens: int | None = None,
        repetition_penalty: float | None = None,
        **attrs: t.Any,
    ) -> tuple[str, dict[str, t.Any]]:
        """Return the prompt to send and the flattened generation parameters.

        The prompt is rewritten into the fill-in-the-middle layout when it
        contains ``FIM_INDICATOR``. The parameters are built with
        ``self.config.model_construct_env`` and dumped with ``flatten=True``.

        Raises:
            ValueError: if ``FIM_INDICATOR`` appears more than once in ``prompt``.
        """
        prompt = _build_fim_prompt(prompt)

        generation_parameters = self.config.model_construct_env(
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=_STARCODER_PAD_TOKEN_ID,
            repetition_penalty=repetition_penalty,
            **attrs,
        ).model_dump(flatten=True)
        return prompt, generation_parameters

    def postprocess_parameters(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
        """Return the first generated string; ``prompt`` and extra keywords are ignored."""
        return generation_result[0]

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        temperature: float | None = None,
        top_p: float | None = None,
        max_new_tokens: int | None = None,
        repetition_penalty: float | None = None,
        **attrs: t.Any,
    ) -> list[str]:
        """Sample a completion for ``prompt`` and return it as a one-item list.

        The prompt is rewritten into the fill-in-the-middle layout when it
        contains ``FIM_INDICATOR``, encoded, moved to ``self.device`` and passed
        to ``self.model.generate`` with sampling enabled. Only the first returned
        sequence is decoded, using the tokenizer's default decode settings.

        Raises:
            ValueError: if ``FIM_INDICATOR`` appears more than once in ``prompt``.
        """
        prompt = _build_fim_prompt(prompt)

        inputs = t.cast("torch.Tensor", self.tokenizer.encode(prompt, return_tensors="pt")).to(self.device)
        generation_config = self.config.model_construct_env(
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=_STARCODER_PAD_TOKEN_ID,
            repetition_penalty=repetition_penalty,
            **attrs,
        ).to_generation_config()
        result_tensor = self.model.generate(
            inputs,
            do_sample=True,
            generation_config=generation_config,
        )
        # TODO: We will probably want to return the tokenizer here so that we can manually process this
        # return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
        return [self.tokenizer.decode(result_tensor[0])]
|