Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-30 18:32:18 -05:00)
fix(gptq): use upstream integration (#297)
* wip
* feat: GPTQ transformers integration
* fix: only load if variable is available and add changelog
* chore: remove boilerplate check

---------

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
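For context: "upstream integration" here means deferring GPTQ loading to the native support that shipped in transformers 4.32 via `GPTQConfig` (backed by optimum and auto-gptq), instead of a bespoke code path inside OpenLLM. A minimal sketch of that upstream API, assuming transformers>=4.32 with optimum, auto-gptq and accelerate installed; the model ids below are examples only, not defaults used by OpenLLM:

```python
# Sketch of the upstream transformers GPTQ path this commit delegates to.
# Assumes transformers>=4.32 plus optimum, auto-gptq and accelerate; model ids are examples.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = 'facebook/opt-125m'
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Quantise while loading: transformers drives auto-gptq through optimum.
gptq_config = GPTQConfig(bits=4, dataset='c4', tokenizer=tokenizer)
quantised = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=gptq_config)

# Loading an already-quantised GPTQ checkpoint needs no config at all:
# the quantisation metadata is read from the checkpoint's config.json.
prequantised = AutoModelForCausalLM.from_pretrained('TheBloke/opt-125M-GPTQ', device_map='auto')
```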
@@ -22,6 +22,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
 from openllm_core._typing_compat import Concatenate
 from openllm_core._typing_compat import DictStrAny
 from openllm_core._typing_compat import LiteralBackend
+from openllm_core._typing_compat import LiteralQuantise
 from openllm_core._typing_compat import LiteralString
 from openllm_core._typing_compat import ParamSpec
 from openllm_core._typing_compat import get_literal_args
@@ -131,15 +132,15 @@ Available official model_id(s): [default: {llm_config['default_id']}]
 model_version: str | None,
 workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
 device: t.Tuple[str, ...],
-quantize: t.Literal['int8', 'int4', 'gptq'] | None,
+quantize: LiteralQuantise | None,
 backend: LiteralBackend,
-serialisation_format: t.Literal['safetensors', 'legacy'],
+serialisation: t.Literal['safetensors', 'legacy'],
 cors: bool,
 adapter_id: str | None,
 return_process: bool,
 **attrs: t.Any,
 ) -> LLMConfig | subprocess.Popen[bytes]:
-if serialisation_format == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
+if serialisation == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
 termui.echo(
 f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
 fg='yellow')
@@ -184,11 +185,11 @@ Available official model_id(s): [default: {llm_config['default_id']}]
 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
-'OPENLLM_SERIALIZATION': serialisation_format,
-env.backend: env['backend_value']
+'OPENLLM_SERIALIZATION': serialisation,
+env.backend: env['backend_value'],
 })
 if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
-if quantize is not None: start_env[env.quantize] = str(t.cast(str, env['quantize_value']))
+if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])

 llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
 model_id=start_env[env.model_id],
@@ -196,7 +197,8 @@ Available official model_id(s): [default: {llm_config['default_id']}]
 llm_config=config,
 ensure_available=True,
 adapter_map=adapter_map,
-serialisation=serialisation_format)
+quantize=env['quantize_value'],
+serialisation=serialisation)
 start_env.update({env.config: llm.config.model_dump_json().decode()})

 server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
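The change above means quantisation is exported to the spawned server only when a value is actually present, instead of being cast unconditionally. A rough sketch of that pattern; the `OPENLLM_QUANTIZE` name and the launch command are illustrative placeholders, not OpenLLM's exact `EnvVarMixin` keys:

```python
# Hedged sketch: forward quantize/serialisation to the child server process only
# when they carry a value. Environment variable names here are hypothetical.
import os
import subprocess
import sys
import typing as t

def launch_server(serialisation: str, quantize: t.Optional[str] = None) -> 'subprocess.Popen[bytes]':
  start_env = dict(os.environ, OPENLLM_SERIALIZATION=serialisation)
  if quantize:  # unset or empty values are simply not forwarded
    start_env['OPENLLM_QUANTIZE'] = str(quantize)
  return subprocess.Popen([sys.executable, '-m', 'openllm', 'start', 'opt'], env=start_env)
```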
@@ -262,8 +264,7 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab

 - DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
 - GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
-''',
-),
+'''),
 quantize_option(factory=cog.optgroup),
 serialisation_option(factory=cog.optgroup),
 cog.optgroup.option('--device',
@@ -457,7 +458,7 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
 def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
 return cli_option('--serialisation',
 '--serialization',
-'serialisation_format',
+'serialisation',
 type=click.Choice(['safetensors', 'legacy']),
 default='safetensors',
 show_default=True,
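The rename above only changes the Click destination variable from `serialisation_format` to `serialisation`; both the `--serialisation` and `--serialization` spellings keep working. A small standalone sketch of that Click pattern, independent of OpenLLM's `cli_option` wrapper:

```python
# Minimal Click sketch: two flag spellings mapped onto one explicitly named destination.
import click

@click.command()
@click.option('--serialisation', '--serialization', 'serialisation',
              type=click.Choice(['safetensors', 'legacy']), default='safetensors', show_default=True)
def cli(serialisation: str) -> None:
  click.echo(f'using {serialisation} serialisation')

if __name__ == '__main__':
  cli()
```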
@@ -26,6 +26,7 @@ if t.TYPE_CHECKING:
 from openllm_core._typing_compat import LiteralBackend
 from openllm_core._typing_compat import LiteralContainerRegistry
 from openllm_core._typing_compat import LiteralContainerVersionStrategy
+from openllm_core._typing_compat import LiteralQuantise
 from openllm_core._typing_compat import LiteralString

 logger = logging.getLogger(__name__)
@@ -37,7 +38,7 @@ def _start(model_name: str,
 timeout: int = 30,
 workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
 device: tuple[str, ...] | t.Literal['all'] | None = None,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
+quantize: LiteralQuantise | None = None,
 adapter_map: dict[LiteralString, str | None] | None = None,
 backend: LiteralBackend | None = None,
 additional_args: list[str] | None = None,
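A hedged usage sketch of the `_start` signature above from the Python side, assuming `_start` is what OpenLLM re-exports as `openllm.start`; the model name, model id and option values are examples only:

```python
import openllm

# quantize accepts the LiteralQuantise values: 'int8' | 'int4' | 'gptq'
openllm.start('opt',
              model_id='facebook/opt-2.7b',
              backend='pt',
              quantize='gptq')
```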
@@ -109,7 +110,7 @@ def _build(model_name: str,
 model_id: str | None = None,
 model_version: str | None = None,
 bento_version: str | None = None,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
+quantize: LiteralQuantise | None = None,
 adapter_map: dict[str, str | None] | None = None,
 build_ctx: str | None = None,
 enable_features: tuple[str, ...] | None = None,
@@ -120,7 +121,7 @@
 container_version_strategy: LiteralContainerVersionStrategy | None = None,
 push: bool = False,
 containerize: bool = False,
-serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
+serialisation: t.Literal['safetensors', 'legacy'] = 'safetensors',
 additional_args: list[str] | None = None,
 bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> bentoml.Bento:
 """Package a LLM into a Bento.
@@ -160,14 +161,14 @@
 container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR.
 container_registry: Container registry to choose the base OpenLLM container image to build from. Default to ECR.
 container_version_strategy: The container version strategy. Default to the latest release of OpenLLM.
-serialisation_format: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True`
+serialisation: Serialisation for saving models. Default to 'safetensors', which is equivalent to `safe_serialization=True`
 additional_args: Additional arguments to pass to ``openllm build``.
 bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store.

 Returns:
 ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
 """
-args: list[str] = [sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--serialisation', serialisation_format]
+args: list[str] = [sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--serialisation', serialisation]
 if quantize: args.extend(['--quantize', quantize])
 if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
 if push: args.extend(['--push'])
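A hedged sketch of the renamed keyword from the Python side, assuming `_build` is what OpenLLM re-exports as `openllm.build`; the model name and id are examples only:

```python
import openllm

bento = openllm.build('opt',
                      model_id='facebook/opt-2.7b',
                      quantize='gptq',               # forwarded to the CLI as ['--quantize', 'gptq']
                      serialisation='safetensors')   # keyword was serialisation_format before this change
print(bento.tag)
```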
@@ -203,8 +204,8 @@ def _import_model(model_name: str,
 model_id: str | None = None,
 model_version: str | None = None,
 backend: LiteralBackend = 'pt',
-quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
-serialisation_format: t.Literal['legacy', 'safetensors'] = 'safetensors',
+quantize: LiteralQuantise | None = None,
+serialisation: t.Literal['legacy', 'safetensors'] = 'safetensors',
 additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
 """Import a LLM into local store.

@@ -228,7 +229,7 @@ def _import_model(model_name: str,
 - int8: Quantize the model with 8bit (bitsandbytes required)
 - int4: Quantize the model with 4bit (bitsandbytes required)
 - gptq: Quantize the model with GPTQ (auto-gptq required)
-serialisation_format: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors.
+serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors.
 Default behaviour is similar to ``safe_serialization=False``.
 additional_args: Additional arguments to pass to ``openllm import``.

@@ -236,7 +237,7 @@ def _import_model(model_name: str,
 ``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
 """
 from .entrypoint import import_command
-args = [model_name, '--backend', backend, '--machine', '--serialisation', serialisation_format]
+args = [model_name, '--backend', backend, '--machine', '--serialisation', serialisation]
 if model_id is not None: args.append(model_id)
 if model_version is not None: args.extend(['--model-version', str(model_version)])
 if additional_args is not None: args.extend(additional_args)
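Correspondingly, a hedged sketch for the import path, assuming `_import_model` is what OpenLLM re-exports as `openllm.import_model`; the checkpoint is an example only:

```python
import openllm

model = openllm.import_model('opt',
                             model_id='facebook/opt-2.7b',
                             backend='pt',
                             quantize='gptq',            # requires auto-gptq
                             serialisation='safetensors')
print(model.tag)
```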
@@ -54,7 +54,6 @@ import openllm
 from bentoml._internal.configuration.containers import BentoMLContainer
 from bentoml._internal.models.model import ModelStore
 from openllm import bundle
-from openllm import serialisation
 from openllm.exceptions import OpenLLMException
 from openllm.models.auto import CONFIG_MAPPING
 from openllm.models.auto import MODEL_FLAX_MAPPING_NAMES
@@ -67,6 +66,7 @@ from openllm.utils import infer_auto_class
 from openllm_core._typing_compat import Concatenate
 from openllm_core._typing_compat import DictStrAny
 from openllm_core._typing_compat import LiteralBackend
+from openllm_core._typing_compat import LiteralQuantise
 from openllm_core._typing_compat import LiteralString
 from openllm_core._typing_compat import ParamSpec
 from openllm_core._typing_compat import Self
@@ -84,7 +84,6 @@ from openllm_core.utils import first_not_none
 from openllm_core.utils import get_debug_mode
 from openllm_core.utils import get_quiet_mode
 from openllm_core.utils import is_torch_available
-from openllm_core.utils import is_transformers_supports_agent
 from openllm_core.utils import resolve_user_filepath
 from openllm_core.utils import set_debug_mode
 from openllm_core.utils import set_quiet_mode
@@ -343,8 +342,8 @@ def import_command(
 output: LiteralOutput,
 machine: bool,
 backend: LiteralBackend,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None,
-serialisation_format: t.Literal['safetensors', 'legacy'],
+quantize: LiteralQuantise | None,
+serialisation: t.Literal['safetensors', 'legacy'],
 ) -> bentoml.Model:
 """Setup LLM interactively.

@@ -369,7 +368,7 @@

 \b
 ```bash
-$ openllm download opt facebook/opt-2.7b
+$ openllm import opt facebook/opt-2.7b
 ```

 \b
@@ -400,17 +399,19 @@
 env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
 backend = first_not_none(backend, default=env['backend_value'])
 llm = infer_auto_class(backend).for_model(
-model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format
+model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False,
+quantize=env['quantize_value'],
+serialisation=serialisation
 )
 _previously_saved = False
 try:
-_ref = serialisation.get(llm)
+_ref = openllm.serialisation.get(llm)
 _previously_saved = True
 except openllm.exceptions.OpenLLMException:
 if not machine and output == 'pretty':
 msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
 termui.echo(msg, fg='yellow', nl=True)
-_ref = serialisation.get(llm, auto_import=True)
+_ref = openllm.serialisation.get(llm, auto_import=True)
 if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
 if machine: return _ref
 elif output == 'pretty':
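Why `serialisation.get(llm)` becomes `openllm.serialisation.get(llm)`: the command's new `serialisation` parameter shadows the module that the dropped `from openllm import serialisation` used to bind, so the module now has to be reached through its package. A tiny standalone sketch of the same shadowing effect, using the standard library rather than OpenLLM:

```python
# Illustrative only: a parameter that reuses a module's name hides the module inside the function.
from os import path  # stands in for `from openllm import serialisation`

def build(path: str) -> str:
  # path.join('a', 'b') would now raise AttributeError: the str parameter shadows the module.
  import os
  return os.path.join(path, 'bento')  # reach the module through its package instead
```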
@@ -472,7 +473,7 @@ def build_command(
 bento_version: str | None,
 overwrite: bool,
 output: LiteralOutput,
-quantize: t.Literal['int8', 'int4', 'gptq'] | None,
+quantize: LiteralQuantise | None,
 enable_features: tuple[str, ...] | None,
 workers_per_resource: float | None,
 adapter_id: tuple[str, ...],
@@ -483,7 +484,7 @@
 dockerfile_template: t.TextIO | None,
 containerize: bool,
 push: bool,
-serialisation_format: t.Literal['safetensors', 'legacy'],
+serialisation: t.Literal['safetensors', 'legacy'],
 container_registry: LiteralContainerRegistry,
 container_version_strategy: LiteralContainerVersionStrategy,
 force_push: bool,
@@ -517,12 +518,12 @@
 # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
 # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
 try:
-os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': serialisation_format, 'OPENLLM_BACKEND': env['backend_value']})
+os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': serialisation, env.backend: env['backend_value']})
 if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
 if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])

 llm = infer_auto_class(env['backend_value']).for_model(
-model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, serialisation=serialisation_format, **attrs
+model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=serialisation, **attrs
 )

 labels = dict(llm.identifying_params)
@@ -798,7 +799,6 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
 except http.client.BadStatusLine:
 raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
 if agent == 'hf':
-if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'")
 _memoized = {k: v[0] for k, v in _memoized.items() if v}
 client._hf_agent.set_stream(logger.info)
 if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
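The dropped guard is the 'boilerplate check' mentioned in the commit message; the `hf` agent path otherwise still drives a Hugging Face agent under the hood. For reference, a minimal sketch of the transformers agent API that this path builds on, assuming transformers>=4.29; the inference endpoint is an example, not something OpenLLM configures:

```python
# Hedged sketch of the transformers HfAgent API behind the 'hf' agent path.
from transformers import HfAgent

agent = HfAgent('https://api-inference.huggingface.co/models/bigcode/starcoder')
print(agent.run("Translate 'good morning' into French"))
```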