fix(gptq): use upstream integration (#297)

* wip

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

* feat: GPTQ transformers integration

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

* fix: only load if variable is available and add changelog

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

* chore: remove boilerplate check

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2023-09-04 14:05:50 -04:00
Committed by: GitHub
Parent: 3da869e728
Commit: 956b3a53bc
23 changed files with 197 additions and 248 deletions
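
For context, "upstream integration" here means relying on the GPTQ support that ships with transformers itself (via optimum/auto-gptq) instead of a bespoke loading path inside OpenLLM. A minimal sketch of that upstream API follows; the version floor and the model id are illustrative assumptions, not OpenLLM code.

```python
# Minimal sketch of the upstream GPTQ path in transformers (>= 4.32, with optimum and
# auto-gptq installed); illustration of the integration being adopted, not OpenLLM's code.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = 'facebook/opt-2.7b'  # example model, same one used in the docstring below
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Quantise the weights to 4-bit at load time; transformers delegates the heavy lifting
# to optimum/auto-gptq under the hood.
quantization_config = GPTQConfig(bits=4, dataset='c4', tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=quantization_config)
```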

@@ -54,7 +54,6 @@ import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
from openllm import bundle
-from openllm import serialisation
from openllm.exceptions import OpenLLMException
from openllm.models.auto import CONFIG_MAPPING
from openllm.models.auto import MODEL_FLAX_MAPPING_NAMES
@@ -67,6 +66,7 @@ from openllm.utils import infer_auto_class
from openllm_core._typing_compat import Concatenate
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import LiteralBackend
+from openllm_core._typing_compat import LiteralQuantise
from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import Self
@@ -84,7 +84,6 @@ from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_torch_available
-from openllm_core.utils import is_transformers_supports_agent
from openllm_core.utils import resolve_user_filepath
from openllm_core.utils import set_debug_mode
from openllm_core.utils import set_quiet_mode
@@ -343,8 +342,8 @@ def import_command(
output: LiteralOutput,
machine: bool,
backend: LiteralBackend,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None,
-serialisation_format: t.Literal['safetensors', 'legacy'],
+quantize: LiteralQuantise | None,
+serialisation: t.Literal['safetensors', 'legacy'],
) -> bentoml.Model:
"""Setup LLM interactively.
@@ -369,7 +368,7 @@ def import_command(
\b
```bash
-$ openllm download opt facebook/opt-2.7b
+$ openllm import opt facebook/opt-2.7b
```
\b
@@ -400,17 +399,19 @@ def import_command(
env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
backend = first_not_none(backend, default=env['backend_value'])
llm = infer_auto_class(backend).for_model(
-model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format
+model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False,
+quantize=env['quantize_value'],
+serialisation=serialisation
)
_previously_saved = False
try:
-_ref = serialisation.get(llm)
+_ref = openllm.serialisation.get(llm)
_previously_saved = True
except openllm.exceptions.OpenLLMException:
if not machine and output == 'pretty':
msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
termui.echo(msg, fg='yellow', nl=True)
-_ref = serialisation.get(llm, auto_import=True)
+_ref = openllm.serialisation.get(llm, auto_import=True)
if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
if machine: return _ref
elif output == 'pretty':
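
The `backend = first_not_none(...)` line above encodes the precedence rule for this whole hunk: an explicit CLI value wins, otherwise the `EnvVarMixin`-derived default applies. A behavioural sketch of that helper (the real one lives in `openllm_core.utils`):

```python
# Behavioural sketch, not the actual implementation: return the first non-None
# argument, falling back to `default` when every argument is None.
from __future__ import annotations

import typing as t

T = t.TypeVar('T')

def first_not_none(*args: T | None, default: T | None = None) -> T | None:
  return next((arg for arg in args if arg is not None), default)
```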
@@ -472,7 +473,7 @@ def build_command(
bento_version: str | None,
overwrite: bool,
output: LiteralOutput,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None,
+quantize: LiteralQuantise | None,
enable_features: tuple[str, ...] | None,
workers_per_resource: float | None,
adapter_id: tuple[str, ...],
@@ -483,7 +484,7 @@ def build_command(
dockerfile_template: t.TextIO | None,
containerize: bool,
push: bool,
-serialisation_format: t.Literal['safetensors', 'legacy'],
+serialisation: t.Literal['safetensors', 'legacy'],
container_registry: LiteralContainerRegistry,
container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool,
@@ -517,12 +518,12 @@ def build_command(
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
try:
-os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': serialisation_format, 'OPENLLM_BACKEND': env['backend_value']})
+os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
llm = infer_auto_class(env['backend_value']).for_model(
-model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, serialisation=serialisation_format, **attrs
+model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=serialisation, **attrs
)
labels = dict(llm.identifying_params)
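
Because `bentoml build` imports `service.py` directly, the CLI forwards the resolved model, serialisation, backend, model id and quantise settings through environment variables first. For a hypothetical GPTQ build the environment ends up roughly like the following; only `OPENLLM_MODEL` and `OPENLLM_SERIALIZATION` appear verbatim in the diff, and the per-model key names produced by `env.backend`, `env.model_id` and `env.quantize` are assumptions.

```python
# Illustration only: approximate environment for an `opt` build with --quantize gptq.
# The last three key names are guesses at what EnvVarMixin generates, not verified.
import os

os.environ.update({
  'OPENLLM_MODEL': 'opt',
  'OPENLLM_SERIALIZATION': 'safetensors',
  'OPENLLM_OPT_BACKEND': 'pt',                  # assumed name from env.backend
  'OPENLLM_OPT_MODEL_ID': 'facebook/opt-2.7b',  # assumed name from env.model_id
  'OPENLLM_OPT_QUANTIZE': 'gptq',               # assumed name from env.quantize
})
```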
@@ -798,7 +799,6 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
except http.client.BadStatusLine:
raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
if agent == 'hf':
-if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'")
_memoized = {k: v[0] for k, v in _memoized.items() if v}
client._hf_agent.set_stream(logger.info)
if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
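
The removed guard existed because `HfAgent` needs transformers >= 4.29 (per the old error message); the commit drops it as boilerplate, presumably because supported transformers versions are already well past that floor. For reference, a guard of that shape looks roughly like this (not the dropped helper's actual code):

```python
# Rough reconstruction of the dropped check's intent: gate HfAgent usage on the
# installed transformers version being at least 4.29.
from importlib.metadata import version

from packaging.version import parse

def is_transformers_supports_agent() -> bool:
  return parse(version('transformers')) >= parse('4.29.0')
```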