From f7e8d9e791e5a441864ff7709756e2c7a8cdc67b Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 22 Mar 2026 00:56:34 +0100 Subject: [PATCH] feat(quantization): add quantization backend (#9096) Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 16 + .github/workflows/test-extra.yml | 27 + Makefile | 6 +- backend/backend.proto | 38 ++ backend/index.yaml | 28 + .../python/llama-cpp-quantization/Makefile | 26 + .../python/llama-cpp-quantization/backend.py | 420 +++++++++++++ .../python/llama-cpp-quantization/install.sh | 58 ++ .../requirements-cpu.txt | 5 + .../requirements-mps.txt | 4 + .../llama-cpp-quantization/requirements.txt | 3 + backend/python/llama-cpp-quantization/run.sh | 10 + backend/python/llama-cpp-quantization/test.py | 99 +++ backend/python/llama-cpp-quantization/test.sh | 11 + core/cli/run.go | 8 - core/config/application_config.go | 14 - core/http/app.go | 25 +- core/http/auth/features.go | 11 + core/http/auth/permissions.go | 5 +- core/http/endpoints/localai/quantization.go | 243 ++++++++ core/http/react-ui/src/components/Sidebar.jsx | 3 +- core/http/react-ui/src/pages/FineTune.jsx | 2 +- core/http/react-ui/src/pages/Quantize.jsx | 485 +++++++++++++++ core/http/react-ui/src/router.jsx | 2 + core/http/react-ui/src/utils/api.js | 13 + core/http/routes/localai.go | 7 +- core/http/routes/quantization.go | 40 ++ core/schema/quantization.go | 55 ++ core/services/quantization.go | 574 ++++++++++++++++++ docs/content/features/fine-tuning.md | 12 +- docs/content/features/quantization.md | 158 +++++ pkg/grpc/backend.go | 5 + pkg/grpc/base/base.go | 12 + pkg/grpc/client.go | 92 +++ pkg/grpc/embed.go | 54 ++ pkg/grpc/interface.go | 5 + pkg/grpc/server.go | 45 ++ 37 files changed, 2574 insertions(+), 47 deletions(-) create mode 100644 backend/python/llama-cpp-quantization/Makefile create mode 100644 backend/python/llama-cpp-quantization/backend.py create mode 100755 backend/python/llama-cpp-quantization/install.sh create mode 100644 
backend/python/llama-cpp-quantization/requirements-cpu.txt create mode 100644 backend/python/llama-cpp-quantization/requirements-mps.txt create mode 100644 backend/python/llama-cpp-quantization/requirements.txt create mode 100755 backend/python/llama-cpp-quantization/run.sh create mode 100644 backend/python/llama-cpp-quantization/test.py create mode 100755 backend/python/llama-cpp-quantization/test.sh create mode 100644 core/http/endpoints/localai/quantization.go create mode 100644 core/http/react-ui/src/pages/Quantize.jsx create mode 100644 core/http/routes/quantization.go create mode 100644 core/schema/quantization.go create mode 100644 core/services/quantization.go create mode 100644 docs/content/features/quantization.md diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 6842d7da1..0ec9bcf58 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -131,6 +131,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-llama-cpp-quantization' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "llama-cpp-quantization" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" @@ -2412,6 +2425,9 @@ jobs: tag-suffix: "-metal-darwin-arm64-local-store" build-type: "metal" lang: "go" + - backend: "llama-cpp-quantization" + tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization" + build-type: "mps" with: backend: ${{ matrix.backend }} build-type: ${{ matrix.build-type }} diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index a4bb15f11..e5f3f8bca 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -383,6 +383,33 @@ jobs: run: | make --jobs=5 
--output-sync=target -C backend/python/voxcpm make --jobs=5 --output-sync=target -C backend/python/voxcpm test + tests-llama-cpp-quantization: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake curl git python3-pip + # Install UV + curl -LsSf https://astral.sh/uv/install.sh | sh + pip install --user --no-cache-dir grpcio-tools==1.64.1 + - name: Build llama-quantize from llama.cpp + run: | + git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp + cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF + cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc) + sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/ + - name: Install backend + run: | + make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization + - name: Test llama-cpp-quantization + run: | + make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test tests-acestep-cpp: runs-on: ubuntu-latest steps: diff --git a/Makefile b/Makefile index e3c02e9c2..459aa9208 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm 
backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization GOCMD=go GOTEST=$(GOCMD) test @@ -575,6 +575,7 @@ BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true BACKEND_TRL = trl|python|.|false|true +BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) @@ -633,12 +634,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) $(eval $(call generate-docker-build-target,$(BACKEND_TRL))) +$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp 
docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization ######################################################## ### Mock Backend for E2E Tests diff --git a/backend/backend.proto b/backend/backend.proto index 91497c523..f89c18571 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -46,6 +46,11 @@ service Backend { rpc StopFineTune(FineTuneStopRequest) returns (Result) {} rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {} rpc ExportModel(ExportModelRequest) returns (Result) {} + + // Quantization RPCs + rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {} + rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {} + rpc StopQuantization(QuantizationStopRequest) returns (Result) {} } // Define the empty request @@ 
-637,3 +642,36 @@ message ExportModelRequest { string model = 5; // base model name (for merge operations) map<string, string> extra_options = 6; } + +// Quantization messages + +message QuantizationRequest { + string model = 1; // HF model name or local path + string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc. + string output_dir = 3; // where to write output files + string job_id = 4; // client-assigned job ID + map<string, string> extra_options = 5; // hf_token, custom flags, etc. +} + +message QuantizationJobResult { + string job_id = 1; + bool success = 2; + string message = 3; +} + +message QuantizationProgressRequest { + string job_id = 1; +} + +message QuantizationProgressUpdate { + string job_id = 1; + float progress_percent = 2; + string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped + string message = 4; + string output_file = 5; // set when completed — path to the output GGUF file + map<string, string> extra_metrics = 6; // e.g. file_size_mb, compression_ratio +} + +message QuantizationStopRequest { + string job_id = 1; +} diff --git a/backend/index.yaml b/backend/index.yaml index 8a8232747..d94cb70be 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -3081,3 +3081,31 @@ uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl" mirrors: - localai/localai-backends:master-cublas-cuda13-trl +## llama.cpp quantization backend +- &llama-cpp-quantization + name: "llama-cpp-quantization" + alias: "llama-cpp-quantization" + license: mit + icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png + description: | + Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format, + and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.). 
+ urls: + - https://github.com/ggml-org/llama.cpp + tags: + - quantization + - GGUF + - CPU + capabilities: + default: "cpu-llama-cpp-quantization" + metal: "metal-darwin-arm64-llama-cpp-quantization" +- !!merge <<: *llama-cpp-quantization + name: "cpu-llama-cpp-quantization" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization" + mirrors: + - localai/localai-backends:latest-cpu-llama-cpp-quantization +- !!merge <<: *llama-cpp-quantization + name: "metal-darwin-arm64-llama-cpp-quantization" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization" + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization diff --git a/backend/python/llama-cpp-quantization/Makefile b/backend/python/llama-cpp-quantization/Makefile new file mode 100644 index 000000000..3e26cf9ba --- /dev/null +++ b/backend/python/llama-cpp-quantization/Makefile @@ -0,0 +1,26 @@ +# Version of llama.cpp to fetch convert_hf_to_gguf.py from +LLAMA_CPP_CONVERT_VERSION ?= master + +.PHONY: llama-cpp-quantization +llama-cpp-quantization: + LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh + +.PHONY: run +run: llama-cpp-quantization + @echo "Running llama-cpp-quantization..." + bash run.sh + @echo "llama-cpp-quantization run." + +.PHONY: test +test: llama-cpp-quantization + @echo "Testing llama-cpp-quantization..." + bash test.sh + @echo "llama-cpp-quantization tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/llama-cpp-quantization/backend.py b/backend/python/llama-cpp-quantization/backend.py new file mode 100644 index 000000000..359133d37 --- /dev/null +++ b/backend/python/llama-cpp-quantization/backend.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python3 +""" +llama.cpp quantization backend for LocalAI. 
+ +Downloads HuggingFace models, converts them to GGUF format using +convert_hf_to_gguf.py, and quantizes using llama-quantize. +""" +import argparse +import os +import queue +import re +import signal +import subprocess +import sys +import threading +import time +from concurrent import futures + +import grpc +import backend_pb2 +import backend_pb2_grpc + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4')) + + +class ActiveJob: + """Tracks a running quantization job.""" + def __init__(self, job_id): + self.job_id = job_id + self.progress_queue = queue.Queue() + self.stop_event = threading.Event() + self.thread = None + self.process = None # subprocess handle for killing + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + def __init__(self): + self.jobs = {} # job_id -> ActiveJob + + def Health(self, request, context): + return backend_pb2.Reply(message=b"OK") + + def LoadModel(self, request, context): + """Accept LoadModel — actual work happens in StartQuantization.""" + return backend_pb2.Result(success=True, message="OK") + + def StartQuantization(self, request, context): + job_id = request.job_id + if job_id in self.jobs: + return backend_pb2.QuantizationJobResult( + job_id=job_id, + success=False, + message=f"Job {job_id} already exists", + ) + + job = ActiveJob(job_id) + self.jobs[job_id] = job + + job.thread = threading.Thread( + target=self._do_quantization, + args=(job, request), + daemon=True, + ) + job.thread.start() + + return backend_pb2.QuantizationJobResult( + job_id=job_id, + success=True, + message="Quantization job started", + ) + + def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None): + update = backend_pb2.QuantizationProgressUpdate( + job_id=job.job_id, + progress_percent=progress_percent, + status=status, + message=message, + output_file=output_file, + extra_metrics=extra_metrics or {}, + ) + job.progress_queue.put(update) + + def 
_do_quantization(self, job, request): + try: + model = request.model + quant_type = request.quantization_type or "q4_k_m" + output_dir = request.output_dir + extra_options = dict(request.extra_options) if request.extra_options else {} + + os.makedirs(output_dir, exist_ok=True) + + if job.stop_event.is_set(): + self._send_progress(job, "stopped", "Job stopped before starting") + return + + # Step 1: Download / resolve model + self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0) + + model_path = self._resolve_model(job, model, output_dir, extra_options) + if model_path is None: + return # error already sent + + if job.stop_event.is_set(): + self._send_progress(job, "stopped", "Job stopped during download") + return + + # Step 2: Convert to f16 GGUF + self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0) + + f16_gguf_path = os.path.join(output_dir, "model-f16.gguf") + if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options): + return # error already sent + + if job.stop_event.is_set(): + self._send_progress(job, "stopped", "Job stopped during conversion") + return + + # Step 3: Quantize + # If the user requested f16, skip quantization — the f16 GGUF is the final output + if quant_type.lower() in ("f16", "fp16"): + output_file = f16_gguf_path + self._send_progress( + job, "completed", + f"Model converted to f16 GGUF: {output_file}", + progress_percent=100.0, + output_file=output_file, + extra_metrics=self._file_metrics(output_file), + ) + return + + output_file = os.path.join(output_dir, f"model-{quant_type}.gguf") + self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0) + + if not self._quantize(job, f16_gguf_path, output_file, quant_type): + return # error already sent + + # Clean up f16 intermediate file to save disk space + try: + os.remove(f16_gguf_path) + except OSError: + pass + + self._send_progress( + job, "completed", 
+ f"Quantization complete: {quant_type}", + progress_percent=100.0, + output_file=output_file, + extra_metrics=self._file_metrics(output_file), + ) + + except Exception as e: + self._send_progress(job, "failed", f"Quantization failed: {str(e)}") + + def _resolve_model(self, job, model, output_dir, extra_options): + """Download model from HuggingFace or return local path.""" + # If it's a local path that exists, use it directly + if os.path.isdir(model): + return model + + # If it looks like a GGUF file path, use it directly + if os.path.isfile(model) and model.endswith(".gguf"): + return model + + # Download from HuggingFace + try: + from huggingface_hub import snapshot_download + + hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN") + cache_dir = os.path.join(output_dir, "hf_cache") + + self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0) + + local_path = snapshot_download( + repo_id=model, + cache_dir=cache_dir, + token=hf_token, + ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"], + ) + + self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0) + return local_path + + except Exception as e: + error_msg = str(e) + if "gated" in error_msg.lower() or "access" in error_msg.lower(): + self._send_progress( + job, "failed", + f"Access denied for {model}. 
This model may be gated — " + f"please accept the license at https://huggingface.co/{model} " + f"and provide your HF token in extra_options.", + ) + else: + self._send_progress(job, "failed", f"Failed to download model: {error_msg}") + return None + + def _convert_to_gguf(self, job, model_path, output_path, extra_options): + """Convert HF model to f16 GGUF using convert_hf_to_gguf.py.""" + # If the model_path is already a GGUF file, just use it as-is + if isinstance(model_path, str) and model_path.endswith(".gguf"): + # Copy or symlink the GGUF file + import shutil + shutil.copy2(model_path, output_path) + return True + + # Find convert_hf_to_gguf.py + convert_script = self._find_convert_script() + if convert_script is None: + self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.") + return False + + cmd = [ + sys.executable, convert_script, + model_path, + "--outfile", output_path, + "--outtype", "f16", + ] + + self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0) + + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + job.process = process + + for line in process.stdout: + line = line.strip() + if line: + self._send_progress(job, "converting", line, progress_percent=40.0) + if job.stop_event.is_set(): + process.kill() + self._send_progress(job, "stopped", "Job stopped during conversion") + return False + + process.wait() + job.process = None + + if process.returncode != 0: + self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}") + return False + + return True + + except Exception as e: + self._send_progress(job, "failed", f"Conversion failed: {str(e)}") + return False + + def _quantize(self, job, input_path, output_path, quant_type): + """Quantize a GGUF file using llama-quantize.""" + quantize_bin = self._find_quantize_binary() + if 
quantize_bin is None: + self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.") + return False + + cmd = [quantize_bin, input_path, output_path, quant_type] + + self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0) + + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + job.process = process + + for line in process.stdout: + line = line.strip() + if line: + # Try to parse progress from llama-quantize output + progress = self._parse_quantize_progress(line) + pct = 55.0 + (progress * 0.40) if progress else 60.0 + self._send_progress(job, "quantizing", line, progress_percent=pct) + if job.stop_event.is_set(): + process.kill() + self._send_progress(job, "stopped", "Job stopped during quantization") + return False + + process.wait() + job.process = None + + if process.returncode != 0: + self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}") + return False + + return True + + except Exception as e: + self._send_progress(job, "failed", f"Quantization failed: {str(e)}") + return False + + def _parse_quantize_progress(self, line): + """Try to parse a progress percentage from llama-quantize output.""" + # llama-quantize typically outputs lines like: + # [ 123/ 1234] quantizing blk.0.attn_k.weight ... 
+ match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line) + if match: + current = int(match.group(1)) + total = int(match.group(2)) + if total > 0: + return current / total + return None + + def _find_convert_script(self): + """Find convert_hf_to_gguf.py in known locations.""" + candidates = [ + # Same directory as this backend + os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"), + # Installed via install.sh + os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"), + ] + + # Also check if it's on PATH + import shutil + path_script = shutil.which("convert_hf_to_gguf.py") + if path_script: + candidates.append(path_script) + + for candidate in candidates: + if os.path.isfile(candidate): + return candidate + return None + + def _find_quantize_binary(self): + """Find llama-quantize binary.""" + import shutil + + # Check common names on PATH + for name in ["llama-quantize", "quantize"]: + path = shutil.which(name) + if path: + return path + + # Check in the backend directory (built by install.sh) + backend_dir = os.path.dirname(os.path.abspath(__file__)) + for name in ["llama-quantize", "quantize"]: + candidate = os.path.join(backend_dir, name) + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + + return None + + def _file_metrics(self, filepath): + """Return file size metrics.""" + try: + size_bytes = os.path.getsize(filepath) + return {"file_size_mb": size_bytes / (1024 * 1024)} + except OSError: + return {} + + def QuantizationProgress(self, request, context): + job_id = request.job_id + job = self.jobs.get(job_id) + if job is None: + context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found") + return + + while True: + try: + update = job.progress_queue.get(timeout=1.0) + yield update + # If this is a terminal status, stop streaming + if update.status in ("completed", "failed", "stopped"): + break + except queue.Empty: + # Check if the thread is still alive + if job.thread and not 
job.thread.is_alive(): + # Thread finished but no terminal update — drain queue + while not job.progress_queue.empty(): + update = job.progress_queue.get_nowait() + yield update + break + # Check if client disconnected + if context.is_active() is False: + break + + def StopQuantization(self, request, context): + job_id = request.job_id + job = self.jobs.get(job_id) + if job is None: + return backend_pb2.Result(success=False, message=f"Job {job_id} not found") + + job.stop_event.set() + if job.process: + try: + job.process.kill() + except OSError: + pass + + return backend_pb2.Result(success=True, message="Stop signal sent") + + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend") + parser.add_argument("--addr", default="localhost:50051", help="gRPC server address") + args = parser.parse_args() + + signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0)) + serve(args.addr) diff --git a/backend/python/llama-cpp-quantization/install.sh b/backend/python/llama-cpp-quantization/install.sh new file mode 100755 index 000000000..05ac24f70 --- /dev/null +++ b/backend/python/llama-cpp-quantization/install.sh @@ -0,0 +1,58 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +EXTRA_PIP_INSTALL_FLAGS+=" --upgrade " +installRequirements + +# Fetch convert_hf_to_gguf.py from llama.cpp +LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}" 
+CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py" +if [ ! -f "${CONVERT_SCRIPT}" ]; then + echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." + curl -L --fail --retry 3 \ + "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \ + -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py." +fi + +# Install gguf package from the same llama.cpp commit to keep them in sync +GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py" +echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." +if [ "x${USE_PIP:-}" == "xtrue" ]; then + pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + pip install "gguf>=0.16.0" + } +else + uv pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + uv pip install "gguf>=0.16.0" + } +fi + +# Build llama-quantize from llama.cpp if not already present +QUANTIZE_BIN="${EDIR}/llama-quantize" +if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then + if command -v cmake &>/dev/null; then + echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." + LLAMA_CPP_SRC="${EDIR}/llama.cpp" + if [ ! 
-d "${LLAMA_CPP_SRC}" ]; then + git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \ + https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \ + git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" + fi + cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF + cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)" + cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}" + chmod +x "${QUANTIZE_BIN}" + echo "Built llama-quantize at ${QUANTIZE_BIN}" + else + echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH." + fi +fi diff --git a/backend/python/llama-cpp-quantization/requirements-cpu.txt b/backend/python/llama-cpp-quantization/requirements-cpu.txt new file mode 100644 index 000000000..4faf5a46e --- /dev/null +++ b/backend/python/llama-cpp-quantization/requirements-cpu.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.10.0 +transformers>=4.56.2 +huggingface-hub>=1.3.0 +sentencepiece diff --git a/backend/python/llama-cpp-quantization/requirements-mps.txt b/backend/python/llama-cpp-quantization/requirements-mps.txt new file mode 100644 index 000000000..1e8d7b155 --- /dev/null +++ b/backend/python/llama-cpp-quantization/requirements-mps.txt @@ -0,0 +1,4 @@ +torch==2.10.0 +transformers>=4.56.2 +huggingface-hub>=1.3.0 +sentencepiece diff --git a/backend/python/llama-cpp-quantization/requirements.txt b/backend/python/llama-cpp-quantization/requirements.txt new file mode 100644 index 000000000..0834a8fcd --- /dev/null +++ b/backend/python/llama-cpp-quantization/requirements.txt @@ -0,0 +1,3 @@ +grpcio==1.78.1 +protobuf +certifi diff --git a/backend/python/llama-cpp-quantization/run.sh b/backend/python/llama-cpp-quantization/run.sh new file mode 100755 index 000000000..bd17c6e1d --- /dev/null +++ 
b/backend/python/llama-cpp-quantization/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/llama-cpp-quantization/test.py b/backend/python/llama-cpp-quantization/test.py new file mode 100644 index 000000000..5cb3a6519 --- /dev/null +++ b/backend/python/llama-cpp-quantization/test.py @@ -0,0 +1,99 @@ +""" +Test script for the llama-cpp-quantization gRPC backend. + +Downloads a small model (functiongemma-270m-it), converts it to GGUF, +and quantizes it to q4_k_m. +""" +import os +import shutil +import subprocess +import tempfile +import time +import unittest + +import grpc +import backend_pb2 +import backend_pb2_grpc + + +SERVER_ADDR = "localhost:50051" +# Small model for CI testing (~540MB) +TEST_MODEL = "unsloth/functiongemma-270m-it" + + +class TestQuantizationBackend(unittest.TestCase): + """Tests for the llama-cpp-quantization gRPC service.""" + + @classmethod + def setUpClass(cls): + cls.service = subprocess.Popen( + ["python3", "backend.py", "--addr", SERVER_ADDR] + ) + time.sleep(5) + cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-") + + @classmethod + def tearDownClass(cls): + cls.service.kill() + cls.service.wait() + # Clean up output directory + if os.path.isdir(cls.output_dir): + shutil.rmtree(cls.output_dir, ignore_errors=True) + + def _channel(self): + return grpc.insecure_channel(SERVER_ADDR) + + def test_01_health(self): + """Test that the server starts and responds to health checks.""" + with self._channel() as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b"OK") + + def test_02_quantize_small_model(self): + """Download, convert, and quantize functiongemma-270m-it to q4_k_m.""" + with self._channel() as channel: + stub = 
backend_pb2_grpc.BackendStub(channel) + + job_id = "test-quantize-001" + + # Start quantization + result = stub.StartQuantization( + backend_pb2.QuantizationRequest( + model=TEST_MODEL, + quantization_type="q4_k_m", + output_dir=self.output_dir, + job_id=job_id, + ) + ) + self.assertTrue(result.success, f"StartQuantization failed: {result.message}") + self.assertEqual(result.job_id, job_id) + + # Stream progress until completion + final_status = None + output_file = None + for update in stub.QuantizationProgress( + backend_pb2.QuantizationProgressRequest(job_id=job_id) + ): + print(f" [{update.status}] {update.progress_percent:.1f}% - {update.message}") + final_status = update.status + if update.output_file: + output_file = update.output_file + + self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}") + self.assertIsNotNone(output_file, "No output_file in progress updates") + self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}") + + # Verify the output is a valid GGUF file (starts with "GGUF" magic) + with open(output_file, "rb") as f: + magic = f.read(4) + self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}") + + # Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB) + size_mb = os.path.getsize(output_file) / (1024 * 1024) + print(f" Output file size: {size_mb:.1f} MB") + self.assertGreater(size_mb, 10, "Output file suspiciously small") + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/python/llama-cpp-quantization/test.sh b/backend/python/llama-cpp-quantization/test.sh new file mode 100755 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/llama-cpp-quantization/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/cli/run.go 
b/core/cli/run.go index 000dd3366..c614e123d 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -121,9 +121,6 @@ type RunCMD struct { AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` - // Fine-tuning - EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"` - // Authentication AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"` @@ -329,11 +326,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithAgentHubURL(r.AgentHubURL)) } - // Fine-tuning - if r.EnableFineTuning { - opts = append(opts, config.EnableFineTuning) - } - // Authentication authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != "" if authEnabled { diff --git a/core/config/application_config.go b/core/config/application_config.go index bb187be43..9c1be82d9 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -97,9 +97,6 @@ type ApplicationConfig struct { // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig - // Fine-tuning - FineTuning FineTuningConfig - // Authentication & Authorization Auth AuthConfig } @@ -145,11 +142,6 @@ type AgentPoolConfig struct { AgentHubURL string // default: "https://agenthub.localai.io" } -// FineTuningConfig holds configuration for fine-tuning support. 
-type FineTuningConfig struct { - Enabled bool -} - type AppOption func(*ApplicationConfig) func NewApplicationConfig(o ...AppOption) *ApplicationConfig { @@ -741,12 +733,6 @@ func WithAgentHubURL(url string) AppOption { } } -// Fine-tuning options - -var EnableFineTuning = func(o *ApplicationConfig) { - o.FineTuning.Enabled = true -} - // Auth options func WithAuthEnabled(enabled bool) AppOption { diff --git a/core/http/app.go b/core/http/app.go index 587fddfe1..94f36c89d 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -304,15 +304,22 @@ func API(application *application.Application) (*echo.Echo, error) { routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw) routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) // Fine-tuning routes - if application.ApplicationConfig().FineTuning.Enabled { - fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning) - ftService := services.NewFineTuneService( - application.ApplicationConfig(), - application.ModelLoader(), - application.ModelConfigLoader(), - ) - routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw) - } + fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning) + ftService := services.NewFineTuneService( + application.ApplicationConfig(), + application.ModelLoader(), + application.ModelConfigLoader(), + ) + routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw) + + // Quantization routes + quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization) + qService := services.NewQuantizationService( + application.ApplicationConfig(), + application.ModelLoader(), + application.ModelConfigLoader(), + ) + 
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw) routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) diff --git a/core/http/auth/features.go b/core/http/auth/features.go index 1d7ff4f61..d029bdc22 100644 --- a/core/http/auth/features.go +++ b/core/http/auth/features.go @@ -97,6 +97,16 @@ var RouteFeatureRegistry = []RouteFeature{ {"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning}, {"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning}, {"POST", "/api/fine-tuning/datasets", FeatureFineTuning}, + + // Quantization + {"POST", "/api/quantization/jobs", FeatureQuantization}, + {"GET", "/api/quantization/jobs", FeatureQuantization}, + {"GET", "/api/quantization/jobs/:id", FeatureQuantization}, + {"POST", "/api/quantization/jobs/:id/stop", FeatureQuantization}, + {"DELETE", "/api/quantization/jobs/:id", FeatureQuantization}, + {"GET", "/api/quantization/jobs/:id/progress", FeatureQuantization}, + {"POST", "/api/quantization/jobs/:id/import", FeatureQuantization}, + {"GET", "/api/quantization/jobs/:id/download", FeatureQuantization}, } // FeatureMeta describes a feature for the admin API/UI. 
@@ -120,6 +130,7 @@ func AgentFeatureMetas() []FeatureMeta { func GeneralFeatureMetas() []FeatureMeta { return []FeatureMeta{ {FeatureFineTuning, "Fine-Tuning", false}, + {FeatureQuantization, "Quantization", false}, } } diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go index 1fd9bbc8b..63bce7d21 100644 --- a/core/http/auth/permissions.go +++ b/core/http/auth/permissions.go @@ -33,7 +33,8 @@ const ( FeatureMCPJobs = "mcp_jobs" // General features (default OFF for new users) - FeatureFineTuning = "fine_tuning" + FeatureFineTuning = "fine_tuning" + FeatureQuantization = "quantization" // API features (default ON for new users) FeatureChat = "chat" @@ -56,7 +57,7 @@ const ( var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs} // GeneralFeatures lists general features (default OFF). -var GeneralFeatures = []string{FeatureFineTuning} +var GeneralFeatures = []string{FeatureFineTuning, FeatureQuantization} // APIFeatures lists API endpoint features (default ON). var APIFeatures = []string{ diff --git a/core/http/endpoints/localai/quantization.go b/core/http/endpoints/localai/quantization.go new file mode 100644 index 000000000..e3bf011da --- /dev/null +++ b/core/http/endpoints/localai/quantization.go @@ -0,0 +1,243 @@ +package localai + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" +) + +// StartQuantizationJobEndpoint starts a new quantization job. 
+func StartQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + + var req schema.QuantizationJobRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + if req.Model == "" { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "model is required", + }) + } + + resp, err := qService.StartJob(c.Request().Context(), userID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusCreated, resp) + } +} + +// ListQuantizationJobsEndpoint lists quantization jobs for the current user. +func ListQuantizationJobsEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobs := qService.ListJobs(userID) + if jobs == nil { + jobs = []*schema.QuantizationJob{} + } + return c.JSON(http.StatusOK, jobs) + } +} + +// GetQuantizationJobEndpoint gets a specific quantization job. +func GetQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + job, err := qService.GetJob(userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, job) + } +} + +// StopQuantizationJobEndpoint stops a running quantization job. 
+func StopQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + err := qService.StopJob(c.Request().Context(), userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "status": "stopped", + "message": "Quantization job stopped", + }) + } +} + +// DeleteQuantizationJobEndpoint deletes a quantization job and its data. +func DeleteQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + err := qService.DeleteJob(userID, jobID) + if err != nil { + status := http.StatusInternalServerError + if strings.Contains(err.Error(), "not found") { + status = http.StatusNotFound + } else if strings.Contains(err.Error(), "cannot delete") { + status = http.StatusConflict + } + return c.JSON(status, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "status": "deleted", + "message": "Quantization job deleted", + }) + } +} + +// QuantizationProgressEndpoint streams progress updates via SSE. 
+func QuantizationProgressEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + // Set SSE headers + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + err := qService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.QuantizationProgressEvent) { + data, err := json.Marshal(event) + if err != nil { + return + } + fmt.Fprintf(c.Response(), "data: %s\n\n", data) + c.Response().Flush() + }) + if err != nil { + // If headers already sent, we can't send a JSON error + fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error()) + c.Response().Flush() + } + + return nil + } +} + +// ImportQuantizedModelEndpoint imports a quantized model into LocalAI. +func ImportQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + var req schema.QuantizationImportRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + modelName, err := qService.ImportModel(c.Request().Context(), userID, jobID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusAccepted, map[string]string{ + "status": "importing", + "message": "Import started for model '" + modelName + "'", + "model_name": modelName, + }) + } +} + +// DownloadQuantizedModelEndpoint streams the quantized model file. 
+func DownloadQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + outputPath, downloadName, err := qService.GetOutputPath(userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.Attachment(outputPath, downloadName) + } +} + +// ListQuantizationBackendsEndpoint returns installed backends tagged with "quantization". +func ListQuantizationBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to list backends: " + err.Error(), + }) + } + + type backendInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Tags []string `json:"tags,omitempty"` + } + + var result []backendInfo + for _, b := range backends { + if !b.Installed { + continue + } + hasTag := false + for _, t := range b.Tags { + if strings.EqualFold(t, "quantization") { + hasTag = true + break + } + } + if !hasTag { + continue + } + name := b.Name + if b.Alias != "" { + name = b.Alias + } + result = append(result, backendInfo{ + Name: name, + Description: b.Description, + Tags: b.Tags, + }) + } + + if result == nil { + result = []backendInfo{} + } + + return c.JSON(http.StatusOK, result) + } +} diff --git a/core/http/react-ui/src/components/Sidebar.jsx b/core/http/react-ui/src/components/Sidebar.jsx index 4b0ebc229..f026d9c5c 100644 --- a/core/http/react-ui/src/components/Sidebar.jsx +++ b/core/http/react-ui/src/components/Sidebar.jsx @@ -20,7 +20,8 @@ const sections = [ id: 'tools', title: 'Tools', items: [ - { path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' }, + { path: 
'/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune (Experimental)', feature: 'fine_tuning' }, + { path: '/app/quantize', icon: 'fas fa-compress', label: 'Quantize (Experimental)', feature: 'quantization' }, ], }, { diff --git a/core/http/react-ui/src/pages/FineTune.jsx b/core/http/react-ui/src/pages/FineTune.jsx index 606848754..fa153efb1 100644 --- a/core/http/react-ui/src/pages/FineTune.jsx +++ b/core/http/react-ui/src/pages/FineTune.jsx @@ -1042,7 +1042,7 @@ export default function FineTune() {
-

Fine-Tuning

+

Fine-Tuning Experimental

Create and manage fine-tuning jobs

diff --git a/core/http/react-ui/src/pages/Quantize.jsx b/core/http/react-ui/src/pages/Quantize.jsx new file mode 100644 index 000000000..bdfc056bc --- /dev/null +++ b/core/http/react-ui/src/pages/Quantize.jsx @@ -0,0 +1,485 @@ +import { useState, useEffect, useCallback, useRef } from 'react' +import { quantizationApi } from '../utils/api' + +const QUANT_PRESETS = [ + 'q2_k', 'q3_k_s', 'q3_k_m', 'q3_k_l', + 'q4_0', 'q4_k_s', 'q4_k_m', + 'q5_0', 'q5_k_s', 'q5_k_m', + 'q6_k', 'q8_0', 'f16', +] +const DEFAULT_QUANT = 'q4_k_m' +const FALLBACK_BACKENDS = ['llama-cpp-quantization'] + +const statusBadgeClass = { + queued: '', downloading: 'badge-warning', converting: 'badge-warning', + quantizing: 'badge-info', completed: 'badge-success', + failed: 'badge-error', stopped: '', +} + +// ── Reusable sub-components ────────────────────────────────────── + +function FormSection({ icon, title, children }) { + return ( +
+

+ {icon && } {title} +

+ {children} +
+ ) +} + +function ProgressMonitor({ job, onClose }) { + const [events, setEvents] = useState([]) + const [latestEvent, setLatestEvent] = useState(null) + const esRef = useRef(null) + + useEffect(() => { + if (!job) return + const terminal = ['completed', 'failed', 'stopped'] + if (terminal.includes(job.status)) return + + const es = new EventSource(quantizationApi.progressUrl(job.id)) + esRef.current = es + es.onmessage = (e) => { + try { + const data = JSON.parse(e.data) + setLatestEvent(data) + setEvents(prev => [...prev.slice(-100), data]) + if (terminal.includes(data.status)) { + es.close() + } + } catch { /* ignore parse errors */ } + } + es.onerror = () => { es.close() } + return () => { es.close() } + }, [job?.id]) + + if (!job) return null + + const progress = latestEvent?.progress_percent ?? 0 + const status = latestEvent?.status ?? job.status + const message = latestEvent?.message ?? job.message ?? '' + + return ( +
+
+

+ + Progress: {job.model} +

+ +
+ +
+ {status} + {message} +
+ +
+
+ {progress > 8 ? `${progress.toFixed(1)}%` : ''} +
+
+ + {/* Log tail */} +
+ {events.slice(-20).map((ev, i) => ( +
+ [{ev.status}] {ev.message} +
+ ))} +
+
+ ) +} + +function ImportPanel({ job, onRefresh }) { + const [modelName, setModelName] = useState('') + const [importing, setImporting] = useState(false) + const [error, setError] = useState('') + const pollRef = useRef(null) + + // Poll for import status + useEffect(() => { + if (job?.import_status !== 'importing') return + pollRef.current = setInterval(async () => { + try { + await onRefresh() + } catch { /* ignore */ } + }, 3000) + return () => clearInterval(pollRef.current) + }, [job?.import_status, onRefresh]) + + if (!job || job.status !== 'completed') return null + + const handleImport = async () => { + setImporting(true) + setError('') + try { + await quantizationApi.importModel(job.id, { name: modelName || undefined }) + await onRefresh() + } catch (e) { + setError(e.message || 'Import failed') + } finally { + setImporting(false) + } + } + + return ( +
+

+ + Output +

+ + {error && ( +
{error}
+ )} + +
+ {/* Download */} + + + Download GGUF + + + {/* Import */} + {job.import_status === 'completed' ? ( + + + Chat with {job.import_model_name} + + ) : job.import_status === 'importing' ? ( + + ) : ( + <> + setModelName(e.target.value)} + style={{ maxWidth: '280px' }} + /> + + + )} +
+ + {job.import_status === 'failed' && ( +
+ Import failed: {job.import_message} +
+ )} +
+ ) +} + +// ── Main page ──────────────────────────────────────────────────── + +export default function Quantize() { + // Form state + const [model, setModel] = useState('') + const [quantType, setQuantType] = useState(DEFAULT_QUANT) + const [customQuantType, setCustomQuantType] = useState('') + const [useCustomQuant, setUseCustomQuant] = useState(false) + const [backend, setBackend] = useState('') + const [hfToken, setHfToken] = useState('') + const [backends, setBackends] = useState([]) + + // Jobs state + const [jobs, setJobs] = useState([]) + const [selectedJob, setSelectedJob] = useState(null) + const [error, setError] = useState('') + const [submitting, setSubmitting] = useState(false) + + // Load backends and jobs + const loadJobs = useCallback(async () => { + try { + const data = await quantizationApi.listJobs() + setJobs(data) + // Refresh selected job if it exists + if (selectedJob) { + const updated = data.find(j => j.id === selectedJob.id) + if (updated) setSelectedJob(updated) + } + } catch { /* ignore */ } + }, [selectedJob?.id]) + + useEffect(() => { + quantizationApi.listBackends().then(b => { + setBackends(b.length ? b : FALLBACK_BACKENDS.map(name => ({ name }))) + if (b.length) setBackend(b[0].name) + else setBackend(FALLBACK_BACKENDS[0]) + }).catch(() => { + setBackends(FALLBACK_BACKENDS.map(name => ({ name }))) + setBackend(FALLBACK_BACKENDS[0]) + }) + loadJobs() + const interval = setInterval(loadJobs, 10000) + return () => clearInterval(interval) + }, []) + + const handleSubmit = async (e) => { + e.preventDefault() + setError('') + setSubmitting(true) + try { + const req = { + model, + backend: backend || FALLBACK_BACKENDS[0], + quantization_type: useCustomQuant ? 
customQuantType : quantType, + } + if (hfToken) { + req.extra_options = { hf_token: hfToken } + } + + const resp = await quantizationApi.startJob(req) + setModel('') + await loadJobs() + + // Select the new job + const refreshed = await quantizationApi.getJob(resp.id) + setSelectedJob(refreshed) + } catch (err) { + setError(err.message || 'Failed to start job') + } finally { + setSubmitting(false) + } + } + + const handleStop = async (jobId) => { + try { + await quantizationApi.stopJob(jobId) + await loadJobs() + } catch (err) { + setError(err.message || 'Failed to stop job') + } + } + + const handleDelete = async (jobId) => { + try { + await quantizationApi.deleteJob(jobId) + if (selectedJob?.id === jobId) setSelectedJob(null) + await loadJobs() + } catch (err) { + setError(err.message || 'Failed to delete job') + } + } + + const effectiveQuantType = useCustomQuant ? customQuantType : quantType + + return ( +
+
+

+ + Model Quantization +

+ Experimental +
+ + {error && ( +
+
{error}
+
+ )} + + {/* ── New Job Form ── */} +
+
+ + setModel(e.target.value)} + required + style={{ width: '100%' }} + /> + + + +
+ + {useCustomQuant && ( + setCustomQuantType(e.target.value)} + required + style={{ maxWidth: '220px' }} + /> + )} +
+
+ + + + + + + setHfToken(e.target.value)} + style={{ maxWidth: '400px' }} + /> + + + +
+ + + {/* ── Selected Job Detail ── */} + {selectedJob && ( + <> + setSelectedJob(null)} + /> + { + const updated = await quantizationApi.getJob(selectedJob.id) + setSelectedJob(updated) + await loadJobs() + }} + /> + + )} + + {/* ── Jobs List ── */} + {jobs.length > 0 && ( +
+

+ + Jobs +

+
+ + + + + + + + + + + + {jobs.map(job => { + const isActive = ['queued', 'downloading', 'converting', 'quantizing'].includes(job.status) + const isSelected = selectedJob?.id === job.id + return ( + setSelectedJob(job)} + > + + + + + + + ) + })} + +
ModelQuantStatusCreatedActions
+ {job.model} + + {job.quantization_type} + + {job.status} + {job.import_status === 'completed' && ( + imported + )} + + {new Date(job.created_at).toLocaleString()} + +
e.stopPropagation()}> + {isActive && ( + + )} + {!isActive && ( + + )} +
+
+
+
+ )} +
+ ) +} diff --git a/core/http/react-ui/src/router.jsx b/core/http/react-ui/src/router.jsx index 023e36ff1..b0daac656 100644 --- a/core/http/react-ui/src/router.jsx +++ b/core/http/react-ui/src/router.jsx @@ -32,6 +32,7 @@ import BackendLogs from './pages/BackendLogs' import Explorer from './pages/Explorer' import Login from './pages/Login' import FineTune from './pages/FineTune' +import Quantize from './pages/Quantize' import Studio from './pages/Studio' import NotFound from './pages/NotFound' import Usage from './pages/Usage' @@ -95,6 +96,7 @@ const appChildren = [ { path: 'agent-jobs/tasks/:id/edit', element: }, { path: 'agent-jobs/jobs/:id', element: }, { path: 'fine-tune', element: }, + { path: 'quantize', element: }, { path: 'model-editor/:name', element: }, { path: 'pipeline-editor', element: }, { path: 'pipeline-editor/:name', element: }, diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index 6fe9bf169..7bffbb2d7 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -410,6 +410,19 @@ export const fineTuneApi = { downloadUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/download`), } +// Quantization API +export const quantizationApi = { + listBackends: () => fetchJSON('/api/quantization/backends'), + startJob: (data) => postJSON('/api/quantization/jobs', data), + listJobs: () => fetchJSON('/api/quantization/jobs'), + getJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}`), + stopJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}/stop`, { method: 'POST' }), + deleteJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}`, { method: 'DELETE' }), + importModel: (id, data) => postJSON(`/api/quantization/jobs/${enc(id)}/import`, data), + progressUrl: (id) => apiUrl(`/api/quantization/jobs/${enc(id)}/progress`), + downloadUrl: (id) => apiUrl(`/api/quantization/jobs/${enc(id)}/download`), +} + // File to base64 helper export function fileToBase64(file) { return new 
Promise((resolve, reject) => { diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 2a2925b00..c91986ab6 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -134,9 +134,10 @@ func RegisterLocalAIRoutes(router *echo.Echo, router.GET("/api/features", func(c echo.Context) error { return c.JSON(200, map[string]bool{ - "agents": appConfig.AgentPool.Enabled, - "mcp": !appConfig.DisableMCP, - "fine_tuning": appConfig.FineTuning.Enabled, + "agents": appConfig.AgentPool.Enabled, + "mcp": !appConfig.DisableMCP, + "fine_tuning": true, + "quantization": true, }) }) diff --git a/core/http/routes/quantization.go b/core/http/routes/quantization.go new file mode 100644 index 000000000..b3dc37b87 --- /dev/null +++ b/core/http/routes/quantization.go @@ -0,0 +1,40 @@ +package routes + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/core/services" +) + +// RegisterQuantizationRoutes registers quantization API routes. 
+func RegisterQuantizationRoutes(e *echo.Echo, qService *services.QuantizationService, appConfig *config.ApplicationConfig, quantizationMw echo.MiddlewareFunc) { + if qService == nil { + return + } + + // Service readiness middleware + readyMw := func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if qService == nil { + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "error": "quantization service is not available", + }) + } + return next(c) + } + } + + q := e.Group("/api/quantization", readyMw, quantizationMw) + q.GET("/backends", localai.ListQuantizationBackendsEndpoint(appConfig)) + q.POST("/jobs", localai.StartQuantizationJobEndpoint(qService)) + q.GET("/jobs", localai.ListQuantizationJobsEndpoint(qService)) + q.GET("/jobs/:id", localai.GetQuantizationJobEndpoint(qService)) + q.POST("/jobs/:id/stop", localai.StopQuantizationJobEndpoint(qService)) + q.DELETE("/jobs/:id", localai.DeleteQuantizationJobEndpoint(qService)) + q.GET("/jobs/:id/progress", localai.QuantizationProgressEndpoint(qService)) + q.POST("/jobs/:id/import", localai.ImportQuantizedModelEndpoint(qService)) + q.GET("/jobs/:id/download", localai.DownloadQuantizedModelEndpoint(qService)) +} diff --git a/core/schema/quantization.go b/core/schema/quantization.go new file mode 100644 index 000000000..7720e7ead --- /dev/null +++ b/core/schema/quantization.go @@ -0,0 +1,55 @@ +package schema + +// QuantizationJobRequest is the REST API request to start a quantization job. +type QuantizationJobRequest struct { + Model string `json:"model"` // HF model name or local path + Backend string `json:"backend"` // "llama-cpp-quantization" + QuantizationType string `json:"quantization_type,omitempty"` // q4_k_m, q5_k_m, q8_0, f16, etc. + ExtraOptions map[string]string `json:"extra_options,omitempty"` +} + +// QuantizationJob represents a quantization job with its current state. 
+type QuantizationJob struct { + ID string `json:"id"` + UserID string `json:"user_id,omitempty"` + Model string `json:"model"` + Backend string `json:"backend"` + ModelID string `json:"model_id,omitempty"` + QuantizationType string `json:"quantization_type"` + Status string `json:"status"` // queued, downloading, converting, quantizing, completed, failed, stopped + Message string `json:"message,omitempty"` + OutputDir string `json:"output_dir"` + OutputFile string `json:"output_file,omitempty"` // path to final GGUF + ExtraOptions map[string]string `json:"extra_options,omitempty"` + CreatedAt string `json:"created_at"` + + // Import state (tracked separately from quantization status) + ImportStatus string `json:"import_status,omitempty"` // "", "importing", "completed", "failed" + ImportMessage string `json:"import_message,omitempty"` + ImportModelName string `json:"import_model_name,omitempty"` // registered model name after import + + // Full config for reuse + Config *QuantizationJobRequest `json:"config,omitempty"` +} + +// QuantizationJobResponse is the REST API response when creating a job. +type QuantizationJobResponse struct { + ID string `json:"id"` + Status string `json:"status"` + Message string `json:"message"` +} + +// QuantizationProgressEvent is an SSE event for quantization progress. +type QuantizationProgressEvent struct { + JobID string `json:"job_id"` + ProgressPercent float32 `json:"progress_percent"` + Status string `json:"status"` + Message string `json:"message,omitempty"` + OutputFile string `json:"output_file,omitempty"` + ExtraMetrics map[string]float32 `json:"extra_metrics,omitempty"` +} + +// QuantizationImportRequest is the REST API request to import a quantized model. 
+type QuantizationImportRequest struct { + Name string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty) +} diff --git a/core/services/quantization.go b/core/services/quantization.go new file mode 100644 index 000000000..bfe51917e --- /dev/null +++ b/core/services/quantization.go @@ -0,0 +1,574 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery/importers" + "github.com/mudler/LocalAI/core/schema" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" + "gopkg.in/yaml.v3" +) + +// QuantizationService manages quantization jobs and their lifecycle. +type QuantizationService struct { + appConfig *config.ApplicationConfig + modelLoader *model.ModelLoader + configLoader *config.ModelConfigLoader + + mu sync.Mutex + jobs map[string]*schema.QuantizationJob +} + +// NewQuantizationService creates a new QuantizationService. +func NewQuantizationService( + appConfig *config.ApplicationConfig, + modelLoader *model.ModelLoader, + configLoader *config.ModelConfigLoader, +) *QuantizationService { + s := &QuantizationService{ + appConfig: appConfig, + modelLoader: modelLoader, + configLoader: configLoader, + jobs: make(map[string]*schema.QuantizationJob), + } + s.loadAllJobs() + return s +} + +// quantizationBaseDir returns the base directory for quantization job data. +func (s *QuantizationService) quantizationBaseDir() string { + return filepath.Join(s.appConfig.DataPath, "quantization") +} + +// jobDir returns the directory for a specific job. +func (s *QuantizationService) jobDir(jobID string) string { + return filepath.Join(s.quantizationBaseDir(), jobID) +} + +// saveJobState persists a job's state to disk as state.json. 
+func (s *QuantizationService) saveJobState(job *schema.QuantizationJob) { + dir := s.jobDir(job.ID) + if err := os.MkdirAll(dir, 0750); err != nil { + xlog.Error("Failed to create quantization job directory", "job_id", job.ID, "error", err) + return + } + + data, err := json.MarshalIndent(job, "", " ") + if err != nil { + xlog.Error("Failed to marshal quantization job state", "job_id", job.ID, "error", err) + return + } + + statePath := filepath.Join(dir, "state.json") + if err := os.WriteFile(statePath, data, 0640); err != nil { + xlog.Error("Failed to write quantization job state", "job_id", job.ID, "error", err) + } +} + +// loadAllJobs scans the quantization directory for persisted jobs and loads them. +func (s *QuantizationService) loadAllJobs() { + baseDir := s.quantizationBaseDir() + entries, err := os.ReadDir(baseDir) + if err != nil { + // Directory doesn't exist yet — that's fine + return + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + statePath := filepath.Join(baseDir, entry.Name(), "state.json") + data, err := os.ReadFile(statePath) + if err != nil { + continue + } + + var job schema.QuantizationJob + if err := json.Unmarshal(data, &job); err != nil { + xlog.Warn("Failed to parse quantization job state", "path", statePath, "error", err) + continue + } + + // Jobs that were running when we shut down are now stale + if job.Status == "queued" || job.Status == "downloading" || job.Status == "converting" || job.Status == "quantizing" { + job.Status = "stopped" + job.Message = "Server restarted while job was running" + } + + // Imports that were in progress are now stale + if job.ImportStatus == "importing" { + job.ImportStatus = "failed" + job.ImportMessage = "Server restarted while import was running" + } + + s.jobs[job.ID] = &job + } + + if len(s.jobs) > 0 { + xlog.Info("Loaded persisted quantization jobs", "count", len(s.jobs)) + } +} + +// StartJob starts a new quantization job. 
+func (s *QuantizationService) StartJob(ctx context.Context, userID string, req schema.QuantizationJobRequest) (*schema.QuantizationJobResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + + jobID := uuid.New().String() + + backendName := req.Backend + if backendName == "" { + backendName = "llama-cpp-quantization" + } + + quantType := req.QuantizationType + if quantType == "" { + quantType = "q4_k_m" + } + + // Always use DataPath for output — not user-configurable + outputDir := filepath.Join(s.quantizationBaseDir(), jobID) + + // Build gRPC request + grpcReq := &pb.QuantizationRequest{ + Model: req.Model, + QuantizationType: quantType, + OutputDir: outputDir, + JobId: jobID, + ExtraOptions: req.ExtraOptions, + } + + // Load the quantization backend (per-job model ID so multiple jobs can run concurrently) + modelID := backendName + "-quantize-" + jobID + backendModel, err := s.modelLoader.Load( + model.WithBackendString(backendName), + model.WithModel(backendName), + model.WithModelID(modelID), + ) + if err != nil { + return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err) + } + + // Start quantization via gRPC + result, err := backendModel.StartQuantization(ctx, grpcReq) + if err != nil { + return nil, fmt.Errorf("failed to start quantization: %w", err) + } + if !result.Success { + return nil, fmt.Errorf("quantization failed to start: %s", result.Message) + } + + // Track the job + job := &schema.QuantizationJob{ + ID: jobID, + UserID: userID, + Model: req.Model, + Backend: backendName, + ModelID: modelID, + QuantizationType: quantType, + Status: "queued", + OutputDir: outputDir, + ExtraOptions: req.ExtraOptions, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + Config: &req, + } + s.jobs[jobID] = job + s.saveJobState(job) + + return &schema.QuantizationJobResponse{ + ID: jobID, + Status: "queued", + Message: result.Message, + }, nil +} + +// GetJob returns a quantization job by ID. 
+func (s *QuantizationService) GetJob(userID, jobID string) (*schema.QuantizationJob, error) { + s.mu.Lock() + defer s.mu.Unlock() + + job, ok := s.jobs[jobID] + if !ok { + return nil, fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + return nil, fmt.Errorf("job not found: %s", jobID) + } + return job, nil +} + +// ListJobs returns all jobs for a user, sorted by creation time (newest first). +func (s *QuantizationService) ListJobs(userID string) []*schema.QuantizationJob { + s.mu.Lock() + defer s.mu.Unlock() + + var result []*schema.QuantizationJob + for _, job := range s.jobs { + if userID == "" || job.UserID == userID { + result = append(result, job) + } + } + + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result +} + +// StopJob stops a running quantization job. +func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string) error { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + s.mu.Unlock() + + // Kill the backend process directly + stopModelID := job.ModelID + if stopModelID == "" { + stopModelID = job.Backend + "-quantize" + } + s.modelLoader.ShutdownModel(stopModelID) + + s.mu.Lock() + job.Status = "stopped" + job.Message = "Quantization stopped by user" + s.saveJobState(job) + s.mu.Unlock() + + return nil +} + +// DeleteJob removes a quantization job and its associated data from disk. 
+func (s *QuantizationService) DeleteJob(userID, jobID string) error { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + + // Reject deletion of actively running jobs + activeStatuses := map[string]bool{ + "queued": true, "downloading": true, "converting": true, "quantizing": true, + } + if activeStatuses[job.Status] { + s.mu.Unlock() + return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status) + } + if job.ImportStatus == "importing" { + s.mu.Unlock() + return fmt.Errorf("cannot delete job %s: import in progress", jobID) + } + + importModelName := job.ImportModelName + delete(s.jobs, jobID) + s.mu.Unlock() + + // Remove job directory (state.json, output files) + jobDir := s.jobDir(jobID) + if err := os.RemoveAll(jobDir); err != nil { + xlog.Warn("Failed to remove quantization job directory", "job_id", jobID, "path", jobDir, "error", err) + } + + // If an imported model exists, clean it up too + if importModelName != "" { + modelsPath := s.appConfig.SystemState.Model.ModelsPath + modelDir := filepath.Join(modelsPath, importModelName) + configPath := filepath.Join(modelsPath, importModelName+".yaml") + + if err := os.RemoveAll(modelDir); err != nil { + xlog.Warn("Failed to remove imported model directory", "path", modelDir, "error", err) + } + if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { + xlog.Warn("Failed to remove imported model config", "path", configPath, "error", err) + } + + // Reload model configs + if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil { + xlog.Warn("Failed to reload configs after delete", "error", err) + } + } + + xlog.Info("Deleted quantization job", "job_id", jobID) + return nil +} + +// StreamProgress opens a gRPC progress stream and calls the 
callback for each update. +func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.QuantizationProgressEvent)) error { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + s.mu.Unlock() + + streamModelID := job.ModelID + if streamModelID == "" { + streamModelID = job.Backend + "-quantize" + } + backendModel, err := s.modelLoader.Load( + model.WithBackendString(job.Backend), + model.WithModel(job.Backend), + model.WithModelID(streamModelID), + ) + if err != nil { + return fmt.Errorf("failed to load backend: %w", err) + } + + return backendModel.QuantizationProgress(ctx, &pb.QuantizationProgressRequest{ + JobId: jobID, + }, func(update *pb.QuantizationProgressUpdate) { + // Update job status and persist + s.mu.Lock() + if j, ok := s.jobs[jobID]; ok { + // Don't let progress updates overwrite terminal states + isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed" + if !isTerminal { + j.Status = update.Status + } + if update.Message != "" { + j.Message = update.Message + } + if update.OutputFile != "" { + j.OutputFile = update.OutputFile + } + s.saveJobState(j) + } + s.mu.Unlock() + + // Convert extra metrics + extraMetrics := make(map[string]float32) + for k, v := range update.ExtraMetrics { + extraMetrics[k] = v + } + + event := &schema.QuantizationProgressEvent{ + JobID: update.JobId, + ProgressPercent: update.ProgressPercent, + Status: update.Status, + Message: update.Message, + OutputFile: update.OutputFile, + ExtraMetrics: extraMetrics, + } + callback(event) + }) +} + +// sanitizeQuantModelName replaces non-alphanumeric characters with hyphens and lowercases. 
+func sanitizeQuantModelName(s string) string { + re := regexp.MustCompile(`[^a-zA-Z0-9\-]`) + s = re.ReplaceAllString(s, "-") + s = regexp.MustCompile(`-+`).ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + return strings.ToLower(s) +} + +// ImportModel imports a quantized model into LocalAI asynchronously. +func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID string, req schema.QuantizationImportRequest) (string, error) { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return "", fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return "", fmt.Errorf("job not found: %s", jobID) + } + if job.Status != "completed" { + s.mu.Unlock() + return "", fmt.Errorf("job %s is not completed (status: %s)", jobID, job.Status) + } + if job.ImportStatus == "importing" { + s.mu.Unlock() + return "", fmt.Errorf("import already in progress for job %s", jobID) + } + if job.OutputFile == "" { + s.mu.Unlock() + return "", fmt.Errorf("no output file for job %s", jobID) + } + s.mu.Unlock() + + // Compute model name + modelName := req.Name + if modelName == "" { + base := sanitizeQuantModelName(job.Model) + if base == "" { + base = "model" + } + shortID := jobID + if len(shortID) > 8 { + shortID = shortID[:8] + } + modelName = base + "-" + job.QuantizationType + "-" + shortID + } + + // Compute output path in models directory + modelsPath := s.appConfig.SystemState.Model.ModelsPath + outputPath := filepath.Join(modelsPath, modelName) + + // Check for name collision + configPath := filepath.Join(modelsPath, modelName+".yaml") + if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil { + return "", fmt.Errorf("invalid model name: %w", err) + } + if _, err := os.Stat(configPath); err == nil { + return "", fmt.Errorf("model %q already exists, choose a different name", modelName) + } + + // Create output directory + if err := os.MkdirAll(outputPath, 0750); err != nil { + 
return "", fmt.Errorf("failed to create output directory: %w", err) + } + + // Set import status + s.mu.Lock() + job.ImportStatus = "importing" + job.ImportMessage = "" + job.ImportModelName = "" + s.saveJobState(job) + s.mu.Unlock() + + // Launch the import in a background goroutine + go func() { + s.setImportMessage(job, "Copying quantized model...") + + // Copy the output file to the models directory + srcFile := job.OutputFile + dstFile := filepath.Join(outputPath, filepath.Base(srcFile)) + + srcData, err := os.ReadFile(srcFile) + if err != nil { + s.setImportFailed(job, fmt.Sprintf("failed to read output file: %v", err)) + return + } + if err := os.WriteFile(dstFile, srcData, 0644); err != nil { + s.setImportFailed(job, fmt.Sprintf("failed to write model file: %v", err)) + return + } + + s.setImportMessage(job, "Generating model configuration...") + + // Auto-import: detect format and generate config + cfg, err := importers.ImportLocalPath(outputPath, modelName) + if err != nil { + s.setImportFailed(job, fmt.Sprintf("model copied to %s but config generation failed: %v", outputPath, err)) + return + } + + cfg.Name = modelName + + // Write YAML config + yamlData, err := yaml.Marshal(cfg) + if err != nil { + s.setImportFailed(job, fmt.Sprintf("failed to marshal config: %v", err)) + return + } + if err := os.WriteFile(configPath, yamlData, 0644); err != nil { + s.setImportFailed(job, fmt.Sprintf("failed to write config file: %v", err)) + return + } + + s.setImportMessage(job, "Registering model with LocalAI...") + + // Reload configs so the model is immediately available + if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil { + xlog.Warn("Failed to reload configs after import", "error", err) + } + if err := s.configLoader.Preload(modelsPath); err != nil { + xlog.Warn("Failed to preload after import", "error", err) + } + + xlog.Info("Quantized model imported and registered", "job_id", jobID, 
"model_name", modelName) + + s.mu.Lock() + job.ImportStatus = "completed" + job.ImportModelName = modelName + job.ImportMessage = "" + s.saveJobState(job) + s.mu.Unlock() + }() + + return modelName, nil +} + +// setImportMessage updates the import message and persists the job state. +func (s *QuantizationService) setImportMessage(job *schema.QuantizationJob, msg string) { + s.mu.Lock() + job.ImportMessage = msg + s.saveJobState(job) + s.mu.Unlock() +} + +// setImportFailed sets the import status to failed with a message. +func (s *QuantizationService) setImportFailed(job *schema.QuantizationJob, message string) { + xlog.Error("Quantization import failed", "job_id", job.ID, "error", message) + s.mu.Lock() + job.ImportStatus = "failed" + job.ImportMessage = message + s.saveJobState(job) + s.mu.Unlock() +} + +// GetOutputPath returns the path to the quantized model file and a download name. +func (s *QuantizationService) GetOutputPath(userID, jobID string) (string, string, error) { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return "", "", fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return "", "", fmt.Errorf("job not found: %s", jobID) + } + if job.Status != "completed" { + s.mu.Unlock() + return "", "", fmt.Errorf("job not completed (status: %s)", job.Status) + } + if job.OutputFile == "" { + s.mu.Unlock() + return "", "", fmt.Errorf("no output file for job %s", jobID) + } + outputFile := job.OutputFile + s.mu.Unlock() + + if _, err := os.Stat(outputFile); os.IsNotExist(err) { + return "", "", fmt.Errorf("output file not found: %s", outputFile) + } + + downloadName := filepath.Base(outputFile) + return outputFile, downloadName, nil +} diff --git a/docs/content/features/fine-tuning.md b/docs/content/features/fine-tuning.md index e5860416c..8dcbbd72a 100644 --- a/docs/content/features/fine-tuning.md +++ b/docs/content/features/fine-tuning.md @@ -13,15 +13,13 @@ LocalAI supports 
fine-tuning LLMs directly through the API and Web UI. Fine-tuni |---------|--------|-------------|-----------------|---------------| | **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full | -## Enabling Fine-Tuning +## Availability -Fine-tuning is disabled by default. Enable it with: +Fine-tuning is always enabled. When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API. -```bash -LOCALAI_ENABLE_FINETUNING=true local-ai -``` - -When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API. +{{% alert note %}} +This feature is **experimental** and may change in future releases. +{{% /alert %}} ## Quick Start diff --git a/docs/content/features/quantization.md b/docs/content/features/quantization.md new file mode 100644 index 000000000..8455145df --- /dev/null +++ b/docs/content/features/quantization.md @@ -0,0 +1,158 @@ ++++ +disableToc = false +title = "Model Quantization" +weight = 19 +url = '/features/quantization/' ++++ + +LocalAI supports model quantization directly through the API and Web UI. Quantization converts HuggingFace models to GGUF format and compresses them to smaller sizes for efficient inference with llama.cpp. + +{{% alert note %}} +This feature is **experimental** and may change in future releases. +{{% /alert %}} + +## Supported Backends + +| Backend | Description | Quantization Types | Platforms | +|---------|-------------|-------------------|-----------| +| **llama-cpp-quantization** | Downloads HF models, converts to GGUF, and quantizes using llama.cpp | q2_k, q3_k_s, q3_k_m, q3_k_l, q4_0, q4_k_s, q4_k_m, q5_0, q5_k_s, q5_k_m, q6_k, q8_0, f16 | CPU (Linux, macOS) | + +## Quick Start + +### 1. 
Start a quantization job + +```bash +curl -X POST http://localhost:8080/api/quantization/jobs \ + -H "Content-Type: application/json" \ + -d '{ + "model": "unsloth/functiongemma-270m-it", + "quantization_type": "q4_k_m" + }' +``` + +### 2. Monitor progress (SSE stream) + +```bash +curl -N http://localhost:8080/api/quantization/jobs/{job_id}/progress +``` + +### 3. Download the quantized model + +```bash +curl -o model.gguf http://localhost:8080/api/quantization/jobs/{job_id}/download +``` + +### 4. Or import it directly into LocalAI + +```bash +curl -X POST http://localhost:8080/api/quantization/jobs/{job_id}/import \ + -H "Content-Type: application/json" \ + -d '{ + "name": "my-quantized-model" + }' +``` + +## API Reference + +### Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `POST` | `/api/quantization/jobs` | Start a quantization job | +| `GET` | `/api/quantization/jobs` | List all jobs | +| `GET` | `/api/quantization/jobs/:id` | Get job details | +| `POST` | `/api/quantization/jobs/:id/stop` | Stop a running job | +| `DELETE` | `/api/quantization/jobs/:id` | Delete a job and its data | +| `GET` | `/api/quantization/jobs/:id/progress` | SSE progress stream | +| `POST` | `/api/quantization/jobs/:id/import` | Import quantized model into LocalAI | +| `GET` | `/api/quantization/jobs/:id/download` | Download quantized GGUF file | +| `GET` | `/api/quantization/backends` | List available quantization backends | + +### Job Request Fields + +| Field | Type | Description | +|-------|------|-------------| +| `model` | string | HuggingFace model ID or local path (required) | +| `backend` | string | Backend name (default: `llama-cpp-quantization`) | +| `quantization_type` | string | Quantization format (default: `q4_k_m`) | +| `extra_options` | map | Backend-specific options (see below) | + +### Extra Options + +| Key | Description | +|-----|-------------| +| `hf_token` | HuggingFace token for gated models | + +### Import Request Fields + 
+| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Model name for LocalAI (auto-generated if empty) | + +### Job Status Values + +| Status | Description | +|--------|-------------| +| `queued` | Job created, waiting to start | +| `downloading` | Downloading model from HuggingFace | +| `converting` | Converting model to f16 GGUF | +| `quantizing` | Running quantization | +| `completed` | Quantization finished successfully | +| `failed` | Job failed (check message for details) | +| `stopped` | Job stopped by user | + +### Progress Stream + +The `GET /api/quantization/jobs/:id/progress` endpoint returns Server-Sent Events (SSE) with JSON payloads: + +```json +{ + "job_id": "abc-123", + "progress_percent": 65.0, + "status": "quantizing", + "message": "[ 234/ 567] quantizing blk.5.attn_k.weight ...", + "output_file": "", + "extra_metrics": {} +} +``` + +When the job completes, `output_file` contains the path to the quantized GGUF file and `extra_metrics` includes `file_size_mb`. + +## Quantization Types + +| Type | Size | Quality | Description | +|------|------|---------|-------------| +| `q2_k` | Smallest | Lowest | 2-bit quantization | +| `q3_k_s` | Very small | Low | 3-bit small | +| `q3_k_m` | Very small | Low | 3-bit medium | +| `q3_k_l` | Small | Low-medium | 3-bit large | +| `q4_0` | Small | Medium | 4-bit legacy | +| `q4_k_s` | Small | Medium | 4-bit small | +| `q4_k_m` | Small | **Good** | **4-bit medium (recommended)** | +| `q5_0` | Medium | Good | 5-bit legacy | +| `q5_k_s` | Medium | Good | 5-bit small | +| `q5_k_m` | Medium | Very good | 5-bit medium | +| `q6_k` | Large | Excellent | 6-bit | +| `q8_0` | Large | Near-lossless | 8-bit | +| `f16` | Largest | Lossless | 16-bit (no quantization, GGUF conversion only) | + +The UI also supports entering a custom quantization type string for any format supported by `llama-quantize`. + +## Web UI + +A "Quantize" page appears in the sidebar under the Tools section. 
The UI provides: + +1. **Job Configuration** — Select model, quantization type (dropdown with presets or custom input), backend, and HuggingFace token +2. **Progress Monitor** — Real-time progress bar and log output via SSE +3. **Jobs List** — View all quantization jobs with status, stop/delete actions +4. **Output** — Download the quantized GGUF file or import it directly into LocalAI for immediate use + +## Architecture + +Quantization uses the same gRPC backend architecture as fine-tuning: + +1. **Proto layer**: `QuantizationRequest`, `QuantizationProgress` (streaming), `StopQuantization` +2. **Python backend**: Downloads model, runs `convert_hf_to_gguf.py` and `llama-quantize` +3. **Go service**: Manages job lifecycle, state persistence, async import +4. **REST API**: HTTP endpoints with SSE progress streaming +5. **React UI**: Configuration form, real-time progress monitor, download/import panel diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index e010d127e..4a64ccbf8 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -70,4 +70,9 @@ type Backend interface { StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) + + // Quantization + StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) + QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error + StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) } diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 481c60762..e9bea6cf9 100644 --- a/pkg/grpc/base/base.go +++ 
b/pkg/grpc/base/base.go @@ -140,6 +140,18 @@ func (llm *Base) ExportModel(*pb.ExportModelRequest) error { return fmt.Errorf("unimplemented") } +func (llm *Base) StartQuantization(*pb.QuantizationRequest) (*pb.QuantizationJobResult, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (llm *Base) QuantizationProgress(*pb.QuantizationProgressRequest, chan *pb.QuantizationProgressUpdate) error { + return fmt.Errorf("unimplemented") +} + +func (llm *Base) StopQuantization(*pb.QuantizationStopRequest) error { + return fmt.Errorf("unimplemented") +} + func memoryUsage() *pb.MemoryUsageData { mud := pb.MemoryUsageData{ Breakdown: make(map[string]uint64), diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index d3ef67814..253167ba0 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -768,6 +768,98 @@ func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opt return client.ExportModel(ctx, in, opts...) } +func (c *Client) StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB + )) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.StartQuantization(ctx, in, opts...) 
+} + +func (c *Client) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB + )) + if err != nil { + return err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + stream, err := client.QuantizationProgress(ctx, in, opts...) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + update, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + f(update) + } + + return nil +} + +func (c *Client) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB + )) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.StopQuantization(ctx, in, opts...) 
+} + func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index a7bc8f40d..f198f25aa 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -147,6 +147,22 @@ func (e *embedBackend) ExportModel(ctx context.Context, in *pb.ExportModelReques return e.s.ExportModel(ctx, in) } +func (e *embedBackend) StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) { + return e.s.StartQuantization(ctx, in) +} + +func (e *embedBackend) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error { + bs := &embedBackendQuantizationProgressStream{ + ctx: ctx, + fn: f, + } + return e.s.QuantizationProgress(in, bs) +} + +func (e *embedBackend) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.StopQuantization(ctx, in) +} + var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream) type embedBackendFineTuneProgressStream struct { @@ -185,6 +201,44 @@ func (e *embedBackendFineTuneProgressStream) RecvMsg(m any) error { return nil } +var _ pb.Backend_QuantizationProgressServer = new(embedBackendQuantizationProgressStream) + +type embedBackendQuantizationProgressStream struct { + ctx context.Context + fn func(update *pb.QuantizationProgressUpdate) +} + +func (e *embedBackendQuantizationProgressStream) Send(update *pb.QuantizationProgressUpdate) error { + e.fn(update) + return nil +} + +func (e *embedBackendQuantizationProgressStream) SetHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendQuantizationProgressStream) SendHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendQuantizationProgressStream) SetTrailer(md 
metadata.MD) { +} + +func (e *embedBackendQuantizationProgressStream) Context() context.Context { + return e.ctx +} + +func (e *embedBackendQuantizationProgressStream) SendMsg(m any) error { + if x, ok := m.(*pb.QuantizationProgressUpdate); ok { + return e.Send(x) + } + return nil +} + +func (e *embedBackendQuantizationProgressStream) RecvMsg(m any) error { + return nil +} + type embedBackendServerStream struct { ctx context.Context fn func(reply *pb.Reply) diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 6f1b5346a..8952a22da 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -42,6 +42,11 @@ type AIModel interface { StopFineTune(*pb.FineTuneStopRequest) error ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) ExportModel(*pb.ExportModelRequest) error + + // Quantization + StartQuantization(*pb.QuantizationRequest) (*pb.QuantizationJobResult, error) + QuantizationProgress(*pb.QuantizationProgressRequest, chan *pb.QuantizationProgressUpdate) error + StopQuantization(*pb.QuantizationStopRequest) error } func newReply(s string) *pb.Reply { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 5bb0ca98b..59d1bb4cb 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -377,6 +377,51 @@ func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*p return &pb.Result{Message: "Model exported", Success: true}, nil } +func (s *server) StartQuantization(ctx context.Context, in *pb.QuantizationRequest) (*pb.QuantizationJobResult, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + res, err := s.llm.StartQuantization(in) + if err != nil { + return &pb.QuantizationJobResult{Success: false, Message: fmt.Sprintf("Error starting quantization: %s", err.Error())}, err + } + return res, nil +} + +func (s *server) QuantizationProgress(in *pb.QuantizationProgressRequest, stream pb.Backend_QuantizationProgressServer) error { + if s.llm.Locking() { + s.llm.Lock() + defer 
s.llm.Unlock() + } + updateChan := make(chan *pb.QuantizationProgressUpdate) + + done := make(chan bool) + go func() { + for update := range updateChan { + stream.Send(update) + } + done <- true + }() + + err := s.llm.QuantizationProgress(in, updateChan) + <-done + + return err +} + +func (s *server) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + err := s.llm.StopQuantization(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error stopping quantization: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Quantization stopped", Success: true}, nil +} + func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) { if s.llm.Locking() { s.llm.Lock()