fix(falcon): loading based on model registration

Remove duplicate analytics events (drops the per-command BuildEvent handler and cli_events_map).

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-06-06 22:42:28 -04:00
parent ffac6d8916
commit aa50b5279e
6 changed files with 15 additions and 67 deletions

View File

@@ -67,7 +67,7 @@ requires-python = ">=3.8"
[project.optional-dependencies]
all = ['openllm[fine-tune]', 'openllm[chatglm]', 'openllm[falcon]', 'openllm[flan-t5]', 'openllm[starcoder]']
chatglm = ['cpm_kernels', 'sentencepiece']
falcon = ['einops']
falcon = ['einops', 'xformers', 'safetensors']
fine-tune = ['peft', 'bitsandbytes', 'datasets']
flan-t5 = ['flax', 'jax', 'jaxlib', 'tensorflow']
starcoder = ['bitsandbytes']

View File

@@ -642,7 +642,6 @@ class LLM(LLMInterface, metaclass=LLMMetaclass):
kwds["accelerator"] = "bettertransformer"
if self.__llm_model__ is None:
# Hmm, bentoml.transformers.load_model doesn't yet support args.
self.__llm_model__ = self._bentomodel.load_model(*self.__llm_args__, **kwds)
if (

View File

@@ -86,7 +86,7 @@ def construct_python_options(llm: openllm.LLM, llm_fs: FS) -> PythonOptions:
"protobuf",
"grpcio",
"grpcio-health-checking",
"opentelemetry-instrumentation-grpc==0.35b0",
"opentelemetry-instrumentation-grpc==0.38b0",
"grpcio-reflection",
]
)

View File

@@ -129,22 +129,16 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
start_time = time.time_ns()
def get_tracking_event(return_value: t.Any):
assert group.name, "Group name is required"
if group.name in analytics.cli_events_map and command_name in analytics.cli_events_map[group.name]:
return analytics.cli_events_map[group.name][command_name](group, command_name, return_value)
return analytics.OpenllmCliEvent(cmd_group=group.name, cmd_name=command_name)
with analytics.set_bentoml_tracking():
assert group.name is not None, "group.name should not be None"
event = analytics.OpenllmCliEvent(cmd_group=group.name, cmd_name=command_name)
try:
return_value = func(*args, **attrs)
event = get_tracking_event(return_value)
duration_in_ms = (time.time_ns() - start_time) / 1e6
event.duration_in_ms = duration_in_ms
analytics.track(event)
return return_value
except Exception as e:
event = get_tracking_event(None)
duration_in_ms = (time.time_ns() - start_time) / 1e6
event.duration_in_ms = duration_in_ms
event.error_type = type(e).__name__
@@ -580,7 +574,7 @@ def cli_factory() -> click.Group:
if output == "pretty":
if not get_quiet_mode():
_echo("\n" + OPENLLM_FIGLET)
_echo("\n" + OPENLLM_FIGLET, fg="white")
if not _previously_built:
_echo(f"Successfully built {bento}.", fg="green")
else:

View File

@@ -35,7 +35,7 @@ class Falcon(openllm.LLM):
default_model = "tiiuae/falcon-7b"
requirements = ["einops"]
requirements = ["einops", "xformers", "safetensors"]
pretrained = ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"]
@@ -49,16 +49,15 @@ class Falcon(openllm.LLM):
device_map = attrs.pop("device_map", "auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map
)
config = transformers.AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
transformers.AutoModelForCausalLM.register(config.__class__, model.__class__)
return bentoml.transformers.save_model(
tag,
transformers.pipeline("text-generation", model=model, tokenizer=tokenizer),
custom_objects={"tokenizer": tokenizer},
pipeline = transformers.pipeline(
"text-generation",
model=pretrained,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
device_map=device_map,
tokenizer=tokenizer,
)
return bentoml.transformers.save_model(tag, pipeline, custom_objects={"tokenizer": tokenizer})
def sanitize_parameters(
self,
@@ -67,7 +66,7 @@ class Falcon(openllm.LLM):
top_k: int | None = None,
num_return_sequences: int | None = None,
eos_token_id: int | None = None,
use_default_prompt_template: bool = True,
use_default_prompt_template: bool = False,
**attrs: t.Any,
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if use_default_prompt_template:

View File

@@ -22,16 +22,13 @@ import contextlib
import functools
import os
import typing as t
from datetime import datetime
import attr
import bentoml
from bentoml._internal.utils import analytics as _internal_analytics
from bentoml._internal.utils.analytics import usage_stats as _internal_usage
if t.TYPE_CHECKING:
import openllm
import click
from ..__about__ import __version__
@@ -77,15 +74,6 @@ class OpenllmCliEvent(_internal_analytics.schemas.EventMeta):
return_code: int = attr.field(default=None)
# Typing-only declarations: this branch is taken only when t.TYPE_CHECKING is true.
if t.TYPE_CHECKING:
# Contravariant type variable standing for the type of a handler's `return_value`.
T_con = t.TypeVar("T_con", contravariant=True)
# Structural type of an event handler: a callable that receives the click group,
# the command name and an optional command return value, and yields an
# OpenllmCliEvent.
class HandlerProtocol(t.Protocol[T_con]):
# Declared as a staticmethod, so the call signature carries no `self` parameter.
@staticmethod
def __call__(group: click.Group, cmd_name: str, return_value: T_con | None = None) -> OpenllmCliEvent:
...
@attr.define
class StartInitEvent(_internal_analytics.schemas.EventMeta):
model_name: str
@@ -111,35 +99,3 @@ def track_start_init(
if do_not_track():
return
track(StartInitEvent.handler(llm_config, supported_gpu))
@attr.define
class BuildEvent(OpenllmCliEvent):
    """CLI tracking event enriched with metadata about the Bento a command produced.

    Every extra field has a default, so an event can still be emitted when the
    command did not return a Bento.
    """

    # Creation time read from ``bento.info.creation_time``; None when no Bento is available.
    bento_creation_timestamp: datetime | None = attr.field(default=None)
    # Size of the whole Bento directory, in GiB (bytes / 1024**3).
    bento_size_in_gb: float = attr.field(default=0)
    # Size of the Bento's "/models" sub-directory, in GiB (bytes / 1024**3).
    model_size_in_gb: float = attr.field(default=0)
    # Value of the Bento's "_type" label; None when the label is absent.
    model_type: str | None = attr.field(default=None)
    # Value of the Bento's "_framework" label; None when the label is absent.
    model_framework: str | None = attr.field(default=None)

    @staticmethod
    def handler(group: click.Group, cmd_name: str, return_value: bentoml.Bento | None = None) -> BuildEvent:
        """Build the tracking event for a CLI command that may have produced a Bento.

        Args:
            group: The click group the command belongs to; its ``name`` must be set.
            cmd_name: Name of the command that was run.
            return_value: The Bento returned by the command, or None when the
                command produced nothing (for example because it failed).

        Returns:
            A ``BuildEvent`` holding the Bento metadata when ``return_value`` is
            given, otherwise one that carries only the group and command names.
        """
        # Imported lazily, as in the original, so the helper is only loaded when an event is built.
        from bentoml._internal.utils import calc_dir_size

        # Narrows ``group.name`` from ``str | None`` to ``str`` for the calls below.
        assert group.name is not None, "group name should not be None"

        if return_value is None:
            return BuildEvent(cmd_group=group.name, cmd_name=cmd_name)

        bento = return_value
        labels = bento.info.labels
        bytes_per_gb = 1024**3
        return BuildEvent(
            cmd_group=group.name,
            cmd_name=cmd_name,
            bento_creation_timestamp=bento.info.creation_time,
            bento_size_in_gb=calc_dir_size(bento.path) / bytes_per_gb,
            model_size_in_gb=calc_dir_size(bento.path_of("/models")) / bytes_per_gb,
            # ``.get`` keeps a missing label from raising KeyError: telemetry
            # should not turn an otherwise successful command into a failure.
            model_type=labels.get("_type"),
            model_framework=labels.get("_framework"),
        )
# Registry of specialised tracking handlers, keyed first by CLI group name and
# then by command name. "build" and "bundle" share the same BuildEvent handler.
cli_events_map: dict[str, dict[str, HandlerProtocol[t.Any]]] = {
    "openllm": dict.fromkeys(("build", "bundle"), BuildEvent.handler),
}