mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-14 19:40:43 -04:00
fix(load): tokenizer and adapter within a BentoLLM (#88)
This commit is contained in:
@@ -28,7 +28,7 @@ svc = bentoml.Service(name="llm-service", runners=[llm_runner])
|
||||
|
||||
@svc.on_startup
|
||||
def download(_: bentoml.Context):
|
||||
llm_runner.llm.ensure_model_id_exists()
|
||||
llm_runner.download_model()
|
||||
|
||||
|
||||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.Text())
|
||||
|
||||
@@ -68,7 +68,7 @@ svc = bentoml.Service("fb-ads-copy", runners=[llm.runner])
|
||||
|
||||
@svc.on_startup
|
||||
def download(_: bentoml.Context):
|
||||
llm.runner.llm.ensure_model_id_exists()
|
||||
llm.runner.download_model()
|
||||
|
||||
|
||||
SAMPLE_INPUT = Query(
|
||||
|
||||
@@ -63,9 +63,9 @@ all = [
|
||||
"openllm[chatglm]",
|
||||
"openllm[starcoder]",
|
||||
"openllm[falcon]",
|
||||
"openllm[agents]",
|
||||
"openllm[fine-tune]",
|
||||
"openllm[openai]",
|
||||
"openllm[fine-tune]",
|
||||
"openllm[agents]",
|
||||
"openllm[flan-t5]",
|
||||
]
|
||||
chatglm = ["cpm_kernels", "sentencepiece"]
|
||||
|
||||
@@ -31,6 +31,8 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import attr
|
||||
import cloudpickle
|
||||
import fs
|
||||
import inflection
|
||||
import orjson
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -38,6 +40,7 @@ from huggingface_hub import hf_hub_download
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.frameworks.transformers import make_default_signatures
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from bentoml._internal.models.model import ModelContext
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
@@ -94,6 +97,9 @@ if t.TYPE_CHECKING:
|
||||
def __call__(self, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def download_model(self, quiet: bool = ...) -> None:
|
||||
...
|
||||
|
||||
class PreTrainedProtocol(t.Protocol):
|
||||
@property
|
||||
def framework(self) -> str:
|
||||
@@ -1146,13 +1152,27 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
"""The tokenizer to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
|
||||
if self.__llm_tokenizer__ is None:
|
||||
if self.model_custom_path:
|
||||
resolved = self.ensure_model_id_exists(quiet=True)
|
||||
assert isinstance(resolved, str)
|
||||
self.__llm_tokenizer__ = transformers.AutoTokenizer.from_pretrained(
|
||||
resolved,
|
||||
trust_remote_code=self.__llm_trust_remote_code__,
|
||||
**self._tokenizer_attrs,
|
||||
)
|
||||
# safe cast here since model is a custom path
|
||||
resolve_fs = fs.open_fs(t.cast(str, self.ensure_model_id_exists(quiet=True)))
|
||||
if resolve_fs.isfile(CUSTOM_OBJECTS_FILENAME):
|
||||
# this branch is hit when loading within the bento.
|
||||
with resolve_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
|
||||
try:
|
||||
self.__llm_tokenizer__ = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
|
||||
except KeyError:
|
||||
# This could happen if users implement their own import_model
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Model does not have tokenizer. Make sure to save \
|
||||
the tokenizer within the model via 'custom_objects'.\
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
|
||||
)
|
||||
else:
|
||||
self.__llm_tokenizer__ = transformers.AutoTokenizer.from_pretrained(
|
||||
resolve_fs.getsyspath("/"),
|
||||
trust_remote_code=self.__llm_trust_remote_code__,
|
||||
**self._tokenizer_attrs,
|
||||
)
|
||||
resolve_fs.close()
|
||||
else:
|
||||
try:
|
||||
if self.__llm_custom_tokenizer__:
|
||||
@@ -1166,7 +1186,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
the tokenizer within the model via 'custom_objects'.\
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
|
||||
)
|
||||
return self.__llm_tokenizer__
|
||||
return t.cast(_T, self.__llm_tokenizer__)
|
||||
|
||||
def _transpose_adapter_mapping(
|
||||
self,
|
||||
@@ -1377,6 +1397,12 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
"result": {},
|
||||
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
|
||||
}
|
||||
if self.__llm_adapter_map__ is None:
|
||||
return {
|
||||
"success": False,
|
||||
"result": {},
|
||||
"error_msg": "No adapters available for current running server.",
|
||||
}
|
||||
if not isinstance(self.model, peft.PeftModel):
|
||||
return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"}
|
||||
return {"success": True, "result": self.model.peft_config, "error_msg": ""}
|
||||
@@ -1388,6 +1414,12 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
"success": False,
|
||||
"error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'",
|
||||
}
|
||||
if self.__llm_adapter_map__ is None:
|
||||
return {
|
||||
"success": False,
|
||||
"result": {},
|
||||
"error_msg": "No adapters available for current running server.",
|
||||
}
|
||||
if not isinstance(self.model, peft.PeftModel):
|
||||
return {"success": False, "error_msg": "Model is not a PeftModel"}
|
||||
try:
|
||||
@@ -1463,6 +1495,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
"identifying_params": self.identifying_params,
|
||||
"llm": self, # NOTE: self reference to LLM
|
||||
"config": self.config,
|
||||
"download_model": self.ensure_model_id_exists,
|
||||
"__call__": _wrapped_generate_run,
|
||||
"__module__": f"openllm.models.{self.config['model_name']}",
|
||||
"__doc__": self.config["env"].start_docstring,
|
||||
|
||||
Reference in New Issue
Block a user