mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 05:36:49 -04:00
feat(quantization): add quantization backend (#9096)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
4b183b7bb6
commit
f7e8d9e791
26
backend/python/llama-cpp-quantization/Makefile
Normal file
26
backend/python/llama-cpp-quantization/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
LLAMA_CPP_CONVERT_VERSION ?= master

# Install the backend: python deps, convert script, llama-quantize binary.
.PHONY: llama-cpp-quantization
llama-cpp-quantization:
	LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh

# Start the gRPC backend (installs first if needed).
.PHONY: run
run: llama-cpp-quantization
	@echo "Running llama-cpp-quantization..."
	bash run.sh
	@echo "llama-cpp-quantization run."

# Run the unit tests against a locally started backend.
.PHONY: test
test: llama-cpp-quantization
	@echo "Testing llama-cpp-quantization..."
	bash test.sh
	@echo "llama-cpp-quantization tested."

# Remove generated gRPC stubs.
.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

# Remove stubs plus the virtualenv and bytecode caches.
.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
|
||||
420
backend/python/llama-cpp-quantization/backend.py
Normal file
420
backend/python/llama-cpp-quantization/backend.py
Normal file
@@ -0,0 +1,420 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
llama.cpp quantization backend for LocalAI.
|
||||
|
||||
Downloads HuggingFace models, converts them to GGUF format using
|
||||
convert_hf_to_gguf.py, and quantizes using llama-quantize.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ActiveJob:
    """State for a single in-flight quantization job.

    Bundles the queue the streaming RPC reads progress from, the flag the
    worker polls for cancellation, and handles to the worker thread and
    any child process so they can be stopped.
    """

    def __init__(self, job_id):
        self.job_id = job_id
        # Cancellation flag, set by StopQuantization and polled by the worker.
        self.stop_event = threading.Event()
        # Progress updates produced by the worker, drained by the
        # QuantizationProgress streaming RPC.
        self.progress_queue = queue.Queue()
        # Handle of the running converter/quantizer subprocess (if any),
        # kept so a stop request can kill it.
        self.process = None
        # Worker thread executing the job; assigned after construction.
        self.thread = None
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer that downloads, converts, and quantizes models.

    Pipeline per job: resolve/download the model (HuggingFace repo id,
    local directory, or existing GGUF file), convert it to an f16 GGUF
    with convert_hf_to_gguf.py, then quantize it with llama-quantize.
    Progress is streamed to clients via the QuantizationProgress RPC.
    """

    def __init__(self):
        self.jobs = {}  # job_id -> ActiveJob
        # gRPC dispatches RPCs on a thread pool, so access to self.jobs
        # (in particular the check-and-insert in StartQuantization) must
        # be synchronized to avoid racing duplicate job registrations.
        self._jobs_lock = threading.Lock()
        # NOTE(review): finished jobs are never evicted from self.jobs,
        # so a job_id cannot be reused and the dict grows over the
        # process lifetime — confirm whether eviction is wanted upstream.

    def Health(self, request, context):
        """Liveness probe; always reports OK."""
        return backend_pb2.Reply(message=b"OK")

    def LoadModel(self, request, context):
        """Accept LoadModel — actual work happens in StartQuantization."""
        return backend_pb2.Result(success=True, message="OK")

    def StartQuantization(self, request, context):
        """Register a new job and run it on a background daemon thread.

        Fails if a job with the same job_id already exists.
        """
        job_id = request.job_id
        with self._jobs_lock:
            if job_id in self.jobs:
                return backend_pb2.QuantizationJobResult(
                    job_id=job_id,
                    success=False,
                    message=f"Job {job_id} already exists",
                )
            job = ActiveJob(job_id)
            self.jobs[job_id] = job

        job.thread = threading.Thread(
            target=self._do_quantization,
            args=(job, request),
            daemon=True,
        )
        job.thread.start()

        return backend_pb2.QuantizationJobResult(
            job_id=job_id,
            success=True,
            message="Quantization job started",
        )

    def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
        """Queue a progress update for consumption by QuantizationProgress."""
        update = backend_pb2.QuantizationProgressUpdate(
            job_id=job.job_id,
            progress_percent=progress_percent,
            status=status,
            message=message,
            output_file=output_file,
            extra_metrics=extra_metrics or {},
        )
        job.progress_queue.put(update)

    def _do_quantization(self, job, request):
        """Worker entry point: download -> convert -> quantize.

        Runs on the job's background thread. Every outcome (success,
        error, cancellation) is reported through _send_progress with a
        terminal status of "completed", "failed", or "stopped".
        """
        try:
            model = request.model
            quant_type = request.quantization_type or "q4_k_m"
            output_dir = request.output_dir
            extra_options = dict(request.extra_options) if request.extra_options else {}

            os.makedirs(output_dir, exist_ok=True)

            if job.stop_event.is_set():
                self._send_progress(job, "stopped", "Job stopped before starting")
                return

            # Step 1: Download / resolve model
            self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)

            model_path = self._resolve_model(job, model, output_dir, extra_options)
            if model_path is None:
                return  # error already sent

            if job.stop_event.is_set():
                self._send_progress(job, "stopped", "Job stopped during download")
                return

            # Step 2: Convert to f16 GGUF
            self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)

            f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
            if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
                return  # error already sent

            if job.stop_event.is_set():
                self._send_progress(job, "stopped", "Job stopped during conversion")
                return

            # Step 3: Quantize
            # If the user requested f16, skip quantization — the f16 GGUF is the final output
            if quant_type.lower() in ("f16", "fp16"):
                output_file = f16_gguf_path
                self._send_progress(
                    job, "completed",
                    f"Model converted to f16 GGUF: {output_file}",
                    progress_percent=100.0,
                    output_file=output_file,
                    extra_metrics=self._file_metrics(output_file),
                )
                return

            output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
            self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)

            if not self._quantize(job, f16_gguf_path, output_file, quant_type):
                return  # error already sent

            # Clean up f16 intermediate file to save disk space
            try:
                os.remove(f16_gguf_path)
            except OSError:
                pass

            self._send_progress(
                job, "completed",
                f"Quantization complete: {quant_type}",
                progress_percent=100.0,
                output_file=output_file,
                extra_metrics=self._file_metrics(output_file),
            )

        except Exception as e:
            self._send_progress(job, "failed", f"Quantization failed: {str(e)}")

    def _resolve_model(self, job, model, output_dir, extra_options):
        """Download model from HuggingFace or return local path.

        Returns a filesystem path to the model, or None after a "failed"
        progress update has been sent.
        """
        # If it's a local path that exists, use it directly
        if os.path.isdir(model):
            return model

        # If it looks like a GGUF file path, use it directly
        if os.path.isfile(model) and model.endswith(".gguf"):
            return model

        # Download from HuggingFace
        try:
            from huggingface_hub import snapshot_download

            hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
            cache_dir = os.path.join(output_dir, "hf_cache")

            self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)

            local_path = snapshot_download(
                repo_id=model,
                cache_dir=cache_dir,
                token=hf_token,
                ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
            )

            self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
            return local_path

        except Exception as e:
            error_msg = str(e)
            # Gated repos surface as access errors — give an actionable hint.
            if "gated" in error_msg.lower() or "access" in error_msg.lower():
                self._send_progress(
                    job, "failed",
                    f"Access denied for {model}. This model may be gated — "
                    f"please accept the license at https://huggingface.co/{model} "
                    f"and provide your HF token in extra_options.",
                )
            else:
                self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
            return None

    def _convert_to_gguf(self, job, model_path, output_path, extra_options):
        """Convert HF model to f16 GGUF using convert_hf_to_gguf.py.

        Returns True on success; on failure sends a "failed"/"stopped"
        update and returns False.
        """
        # If the model_path is already a GGUF file, just use it as-is
        if isinstance(model_path, str) and model_path.endswith(".gguf"):
            # Copy or symlink the GGUF file
            import shutil
            shutil.copy2(model_path, output_path)
            return True

        # Find convert_hf_to_gguf.py
        convert_script = self._find_convert_script()
        if convert_script is None:
            self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
            return False

        cmd = [
            sys.executable, convert_script,
            model_path,
            "--outfile", output_path,
            "--outtype", "f16",
        ]

        self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)

        try:
            # Merge stderr into stdout so all converter output is streamed
            # as progress messages; line-buffered text mode.
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )
            job.process = process

            for line in process.stdout:
                line = line.strip()
                if line:
                    self._send_progress(job, "converting", line, progress_percent=40.0)
                if job.stop_event.is_set():
                    process.kill()
                    self._send_progress(job, "stopped", "Job stopped during conversion")
                    return False

            process.wait()
            job.process = None

            if process.returncode != 0:
                self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
                return False

            return True

        except Exception as e:
            self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
            return False

    def _quantize(self, job, input_path, output_path, quant_type):
        """Quantize a GGUF file using llama-quantize.

        Returns True on success; on failure sends a "failed"/"stopped"
        update and returns False.
        """
        quantize_bin = self._find_quantize_binary()
        if quantize_bin is None:
            self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
            return False

        cmd = [quantize_bin, input_path, output_path, quant_type]

        self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)

        try:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )
            job.process = process

            for line in process.stdout:
                line = line.strip()
                if line:
                    # Try to parse progress from llama-quantize output;
                    # map tensor progress into the 55–95% band.
                    progress = self._parse_quantize_progress(line)
                    pct = 55.0 + (progress * 0.40) if progress else 60.0
                    self._send_progress(job, "quantizing", line, progress_percent=pct)
                if job.stop_event.is_set():
                    process.kill()
                    self._send_progress(job, "stopped", "Job stopped during quantization")
                    return False

            process.wait()
            job.process = None

            if process.returncode != 0:
                self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
                return False

            return True

        except Exception as e:
            self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
            return False

    def _parse_quantize_progress(self, line):
        """Try to parse a progress percentage from llama-quantize output.

        Returns a ratio in (0, 1] or None when the line carries no
        recognizable counter.
        """
        # llama-quantize typically outputs lines like:
        # [ 123/ 1234] quantizing blk.0.attn_k.weight ...
        match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
        if match:
            current = int(match.group(1))
            total = int(match.group(2))
            if total > 0:
                return current / total
        return None

    def _find_convert_script(self):
        """Find convert_hf_to_gguf.py in known locations."""
        # Backend directory (populated by install.sh). The previous
        # implementation listed dirname(__file__) and
        # dirname(abspath(__file__)) as separate candidates, but both
        # resolve to the same file — a single entry suffices.
        candidates = [
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
        ]

        # Also check if it's on PATH
        import shutil
        path_script = shutil.which("convert_hf_to_gguf.py")
        if path_script:
            candidates.append(path_script)

        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate
        return None

    def _find_quantize_binary(self):
        """Find the llama-quantize binary on PATH or in the backend dir."""
        import shutil

        # Check common names on PATH
        for name in ["llama-quantize", "quantize"]:
            path = shutil.which(name)
            if path:
                return path

        # Check in the backend directory (built by install.sh)
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        for name in ["llama-quantize", "quantize"]:
            candidate = os.path.join(backend_dir, name)
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return candidate

        return None

    def _file_metrics(self, filepath):
        """Return file size metrics (empty dict if the file is unreadable)."""
        try:
            size_bytes = os.path.getsize(filepath)
            return {"file_size_mb": size_bytes / (1024 * 1024)}
        except OSError:
            return {}

    def QuantizationProgress(self, request, context):
        """Server-streaming RPC: yield progress updates until the job ends."""
        job_id = request.job_id
        with self._jobs_lock:
            job = self.jobs.get(job_id)
        if job is None:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
            return

        while True:
            try:
                update = job.progress_queue.get(timeout=1.0)
                yield update
                # If this is a terminal status, stop streaming
                if update.status in ("completed", "failed", "stopped"):
                    break
            except queue.Empty:
                # Check if the thread is still alive
                if job.thread and not job.thread.is_alive():
                    # Thread finished but no terminal update — drain queue
                    while not job.progress_queue.empty():
                        update = job.progress_queue.get_nowait()
                        yield update
                    break
                # Check if client disconnected
                if not context.is_active():
                    break

    def StopQuantization(self, request, context):
        """Request cancellation: set the stop flag and kill any subprocess."""
        job_id = request.job_id
        with self._jobs_lock:
            job = self.jobs.get(job_id)
        if job is None:
            return backend_pb2.Result(success=False, message=f"Job {job_id} not found")

        job.stop_event.set()
        # Snapshot the handle: the worker may clear job.process between
        # the truthiness check and the kill() call.
        process = job.process
        if process:
            try:
                process.kill()
            except OSError:
                pass

        return backend_pb2.Result(success=True, message="Stop signal sent")
|
||||
|
||||
|
||||
def serve(address):
    """Start the gRPC backend on *address* and block until interrupted.

    Args:
        address: host:port string passed to add_insecure_port.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)

    # server.start() does not block; keep the main thread alive until
    # Ctrl-C, then shut down immediately (grace period 0).
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
    parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    args = parser.parse_args()

    # Exit cleanly on SIGINT instead of surfacing KeyboardInterrupt
    # tracebacks from worker threads.
    signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
    serve(args.addr)
|
||||
58
backend/python/llama-cpp-quantization/install.sh
Executable file
58
backend/python/llama-cpp-quantization/install.sh
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/bin/bash
set -e

# Load the shared backend helper library (provides installRequirements,
# EDIR, USE_PIP, ...). Expansions are quoted so paths containing spaces
# do not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
installRequirements

# Fetch convert_hf_to_gguf.py from llama.cpp
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
    curl -L --fail --retry 3 \
        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
        -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
fi

# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
    pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        pip install "gguf>=0.16.0"
    }
else
    uv pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        uv pip install "gguf>=0.16.0"
    }
fi

# Build llama-quantize from llama.cpp if not already present
QUANTIZE_BIN="${EDIR}/llama-quantize"
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
    if command -v cmake &>/dev/null; then
        echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
        LLAMA_CPP_SRC="${EDIR}/llama.cpp"
        if [ ! -d "${LLAMA_CPP_SRC}" ]; then
            # Prefer the pinned tag/branch; fall back to the default branch
            # when the ref does not exist (e.g. "master" given as a hash).
            git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
                https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
                git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
        fi
        cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
        cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
        cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
        chmod +x "${QUANTIZE_BIN}"
        echo "Built llama-quantize at ${QUANTIZE_BIN}"
    else
        echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
    fi
fi
|
||||
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
@@ -0,0 +1,4 @@
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
10
backend/python/llama-cpp-quantization/run.sh
Executable file
10
backend/python/llama-cpp-quantization/run.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash

# Load the shared backend helper library; expansions are quoted so paths
# containing spaces do not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Forward all CLI arguments verbatim — "$@" preserves argument boundaries,
# whereas the unquoted $@ would word-split and glob-expand them.
startBackend "$@"
|
||||
99
backend/python/llama-cpp-quantization/test.py
Normal file
99
backend/python/llama-cpp-quantization/test.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Test script for the llama-cpp-quantization gRPC backend.
|
||||
|
||||
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
|
||||
and quantizes it to q4_k_m.
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
SERVER_ADDR = "localhost:50051"
|
||||
# Small model for CI testing (~540MB)
|
||||
TEST_MODEL = "unsloth/functiongemma-270m-it"
|
||||
|
||||
|
||||
class TestQuantizationBackend(unittest.TestCase):
    """Tests for the llama-cpp-quantization gRPC service."""

    @classmethod
    def setUpClass(cls):
        # Launch the backend as a subprocess and give it time to bind its
        # port before the tests open channels to it.
        # NOTE(review): a fixed 5s sleep may be flaky on slow CI — consider
        # polling Health instead.
        cls.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", SERVER_ADDR]
        )
        time.sleep(5)
        cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")

    @classmethod
    def tearDownClass(cls):
        cls.service.kill()
        cls.service.wait()
        # Clean up output directory
        if os.path.isdir(cls.output_dir):
            shutil.rmtree(cls.output_dir, ignore_errors=True)

    def _channel(self):
        # Fresh insecure channel per test; used as a context manager so it
        # is closed on exit.
        return grpc.insecure_channel(SERVER_ADDR)

    def test_01_health(self):
        """Test that the server starts and responds to health checks."""
        with self._channel() as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b"OK")

    def test_02_quantize_small_model(self):
        """Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
        with self._channel() as channel:
            stub = backend_pb2_grpc.BackendStub(channel)

            job_id = "test-quantize-001"

            # Start quantization
            result = stub.StartQuantization(
                backend_pb2.QuantizationRequest(
                    model=TEST_MODEL,
                    quantization_type="q4_k_m",
                    output_dir=self.output_dir,
                    job_id=job_id,
                )
            )
            self.assertTrue(result.success, f"StartQuantization failed: {result.message}")
            self.assertEqual(result.job_id, job_id)

            # Stream progress until completion; remember the last status and
            # any output file reported along the way.
            final_status = None
            output_file = None
            for update in stub.QuantizationProgress(
                backend_pb2.QuantizationProgressRequest(job_id=job_id)
            ):
                print(f"  [{update.status}] {update.progress_percent:.1f}% - {update.message}")
                final_status = update.status
                if update.output_file:
                    output_file = update.output_file

            self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
            self.assertIsNotNone(output_file, "No output_file in progress updates")
            self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")

            # Verify the output is a valid GGUF file (starts with "GGUF" magic)
            with open(output_file, "rb") as f:
                magic = f.read(4)
            self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")

            # Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
            size_mb = os.path.getsize(output_file) / (1024 * 1024)
            print(f"  Output file size: {size_mb:.1f} MB")
            self.assertGreater(size_mb, 10, "Output file suspiciously small")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/llama-cpp-quantization/test.sh
Executable file
11
backend/python/llama-cpp-quantization/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
set -e

# Load the shared backend helper library; expansions are quoted so paths
# containing spaces do not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Run the python unit tests (test.py) via the shared helper.
runUnittests
|
||||
Reference in New Issue
Block a user