mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
feat(quantization): add quantization backend (#9096)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
4b183b7bb6
commit
f7e8d9e791
16
.github/workflows/backend.yml
vendored
16
.github/workflows/backend.yml
vendored
@@ -131,6 +131,19 @@ jobs:
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
platforms: 'linux/amd64,linux/arm64'
|
||||
tag-latest: 'auto'
|
||||
tag-suffix: '-cpu-llama-cpp-quantization'
|
||||
runs-on: 'ubuntu-latest'
|
||||
base-image: "ubuntu:24.04"
|
||||
skip-drivers: 'true'
|
||||
backend: "llama-cpp-quantization"
|
||||
dockerfile: "./backend/Dockerfile.python"
|
||||
context: "./"
|
||||
ubuntu-version: '2404'
|
||||
- build-type: ''
|
||||
cuda-major-version: ""
|
||||
cuda-minor-version: ""
|
||||
@@ -2412,6 +2425,9 @@ jobs:
|
||||
tag-suffix: "-metal-darwin-arm64-local-store"
|
||||
build-type: "metal"
|
||||
lang: "go"
|
||||
- backend: "llama-cpp-quantization"
|
||||
tag-suffix: "-metal-darwin-arm64-llama-cpp-quantization"
|
||||
build-type: "mps"
|
||||
with:
|
||||
backend: ${{ matrix.backend }}
|
||||
build-type: ${{ matrix.build-type }}
|
||||
|
||||
27
.github/workflows/test-extra.yml
vendored
27
.github/workflows/test-extra.yml
vendored
@@ -383,6 +383,33 @@ jobs:
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm
|
||||
make --jobs=5 --output-sync=target -C backend/python/voxcpm test
|
||||
tests-llama-cpp-quantization:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: true
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential cmake curl git python3-pip
|
||||
# Install UV
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
pip install --user --no-cache-dir grpcio-tools==1.64.1
|
||||
- name: Build llama-quantize from llama.cpp
|
||||
run: |
|
||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git /tmp/llama.cpp
|
||||
cmake -B /tmp/llama.cpp/build -S /tmp/llama.cpp -DGGML_NATIVE=OFF
|
||||
cmake --build /tmp/llama.cpp/build --target llama-quantize -j$(nproc)
|
||||
sudo cp /tmp/llama.cpp/build/bin/llama-quantize /usr/local/bin/
|
||||
- name: Install backend
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization
|
||||
- name: Test llama-cpp-quantization
|
||||
run: |
|
||||
make --jobs=5 --output-sync=target -C backend/python/llama-cpp-quantization test
|
||||
tests-acestep-cpp:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
6
Makefile
6
Makefile
@@ -1,5 +1,5 @@
|
||||
# Disable parallel execution for backend builds
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
|
||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl backends/llama-cpp-quantization
|
||||
|
||||
GOCMD=go
|
||||
GOTEST=$(GOCMD) test
|
||||
@@ -575,6 +575,7 @@ BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||
BACKEND_TRL = trl|python|.|false|true
|
||||
BACKEND_LLAMA_CPP_QUANTIZATION = llama-cpp-quantization|python|.|false|true
|
||||
|
||||
# Helper function to build docker image for a backend
|
||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||
@@ -633,12 +634,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||
$(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP_QUANTIZATION)))
|
||||
|
||||
# Pattern rule for docker-save targets
|
||||
docker-save-%: backend-images
|
||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
|
||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl docker-build-llama-cpp-quantization
|
||||
|
||||
########################################################
|
||||
### Mock Backend for E2E Tests
|
||||
|
||||
@@ -46,6 +46,11 @@ service Backend {
|
||||
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||
|
||||
// Quantization RPCs
|
||||
rpc StartQuantization(QuantizationRequest) returns (QuantizationJobResult) {}
|
||||
rpc QuantizationProgress(QuantizationProgressRequest) returns (stream QuantizationProgressUpdate) {}
|
||||
rpc StopQuantization(QuantizationStopRequest) returns (Result) {}
|
||||
}
|
||||
|
||||
// Define the empty request
|
||||
@@ -637,3 +642,36 @@ message ExportModelRequest {
|
||||
string model = 5; // base model name (for merge operations)
|
||||
map<string, string> extra_options = 6;
|
||||
}
|
||||
|
||||
// Quantization messages
|
||||
|
||||
message QuantizationRequest {
|
||||
string model = 1; // HF model name or local path
|
||||
string quantization_type = 2; // q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||
string output_dir = 3; // where to write output files
|
||||
string job_id = 4; // client-assigned job ID
|
||||
map<string, string> extra_options = 5; // hf_token, custom flags, etc.
|
||||
}
|
||||
|
||||
message QuantizationJobResult {
|
||||
string job_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
message QuantizationProgressRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
message QuantizationProgressUpdate {
|
||||
string job_id = 1;
|
||||
float progress_percent = 2;
|
||||
string status = 3; // queued, downloading, converting, quantizing, completed, failed, stopped
|
||||
string message = 4;
|
||||
string output_file = 5; // set when completed — path to the output GGUF file
|
||||
map<string, float> extra_metrics = 6; // e.g. file_size_mb, compression_ratio
|
||||
}
|
||||
|
||||
message QuantizationStopRequest {
|
||||
string job_id = 1;
|
||||
}
|
||||
|
||||
@@ -3081,3 +3081,31 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-cublas-cuda13-trl
|
||||
## llama.cpp quantization backend
|
||||
- &llama-cpp-quantization
|
||||
name: "llama-cpp-quantization"
|
||||
alias: "llama-cpp-quantization"
|
||||
license: mit
|
||||
icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png
|
||||
description: |
|
||||
Model quantization backend using llama.cpp. Downloads HuggingFace models, converts them to GGUF format,
|
||||
and quantizes them to various formats (q4_k_m, q5_k_m, q8_0, f16, etc.).
|
||||
urls:
|
||||
- https://github.com/ggml-org/llama.cpp
|
||||
tags:
|
||||
- quantization
|
||||
- GGUF
|
||||
- CPU
|
||||
capabilities:
|
||||
default: "cpu-llama-cpp-quantization"
|
||||
metal: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "cpu-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-cpu-llama-cpp-quantization
|
||||
- !!merge <<: *llama-cpp-quantization
|
||||
name: "metal-darwin-arm64-llama-cpp-quantization"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp-quantization"
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-llama-cpp-quantization
|
||||
|
||||
26
backend/python/llama-cpp-quantization/Makefile
Normal file
26
backend/python/llama-cpp-quantization/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from
|
||||
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||
|
||||
.PHONY: llama-cpp-quantization
|
||||
llama-cpp-quantization:
|
||||
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: llama-cpp-quantization
|
||||
@echo "Running llama-cpp-quantization..."
|
||||
bash run.sh
|
||||
@echo "llama-cpp-quantization run."
|
||||
|
||||
.PHONY: test
|
||||
test: llama-cpp-quantization
|
||||
@echo "Testing llama-cpp-quantization..."
|
||||
bash test.sh
|
||||
@echo "llama-cpp-quantization tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
420
backend/python/llama-cpp-quantization/backend.py
Normal file
420
backend/python/llama-cpp-quantization/backend.py
Normal file
@@ -0,0 +1,420 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
llama.cpp quantization backend for LocalAI.
|
||||
|
||||
Downloads HuggingFace models, converts them to GGUF format using
|
||||
convert_hf_to_gguf.py, and quantizes using llama-quantize.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ActiveJob:
|
||||
"""Tracks a running quantization job."""
|
||||
def __init__(self, job_id):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = queue.Queue()
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = None
|
||||
self.process = None # subprocess handle for killing
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.jobs = {} # job_id -> ActiveJob
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=b"OK")
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual work happens in StartQuantization."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartQuantization(self, request, context):
|
||||
job_id = request.job_id
|
||||
if job_id in self.jobs:
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=False,
|
||||
message=f"Job {job_id} already exists",
|
||||
)
|
||||
|
||||
job = ActiveJob(job_id)
|
||||
self.jobs[job_id] = job
|
||||
|
||||
job.thread = threading.Thread(
|
||||
target=self._do_quantization,
|
||||
args=(job, request),
|
||||
daemon=True,
|
||||
)
|
||||
job.thread.start()
|
||||
|
||||
return backend_pb2.QuantizationJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Quantization job started",
|
||||
)
|
||||
|
||||
def _send_progress(self, job, status, message, progress_percent=0.0, output_file="", extra_metrics=None):
|
||||
update = backend_pb2.QuantizationProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
progress_percent=progress_percent,
|
||||
status=status,
|
||||
message=message,
|
||||
output_file=output_file,
|
||||
extra_metrics=extra_metrics or {},
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
|
||||
def _do_quantization(self, job, request):
|
||||
try:
|
||||
model = request.model
|
||||
quant_type = request.quantization_type or "q4_k_m"
|
||||
output_dir = request.output_dir
|
||||
extra_options = dict(request.extra_options) if request.extra_options else {}
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped before starting")
|
||||
return
|
||||
|
||||
# Step 1: Download / resolve model
|
||||
self._send_progress(job, "downloading", f"Resolving model: {model}", progress_percent=0.0)
|
||||
|
||||
model_path = self._resolve_model(job, model, output_dir, extra_options)
|
||||
if model_path is None:
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during download")
|
||||
return
|
||||
|
||||
# Step 2: Convert to f16 GGUF
|
||||
self._send_progress(job, "converting", "Converting model to GGUF (f16)...", progress_percent=30.0)
|
||||
|
||||
f16_gguf_path = os.path.join(output_dir, "model-f16.gguf")
|
||||
if not self._convert_to_gguf(job, model_path, f16_gguf_path, extra_options):
|
||||
return # error already sent
|
||||
|
||||
if job.stop_event.is_set():
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return
|
||||
|
||||
# Step 3: Quantize
|
||||
# If the user requested f16, skip quantization — the f16 GGUF is the final output
|
||||
if quant_type.lower() in ("f16", "fp16"):
|
||||
output_file = f16_gguf_path
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Model converted to f16 GGUF: {output_file}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
return
|
||||
|
||||
output_file = os.path.join(output_dir, f"model-{quant_type}.gguf")
|
||||
self._send_progress(job, "quantizing", f"Quantizing to {quant_type}...", progress_percent=50.0)
|
||||
|
||||
if not self._quantize(job, f16_gguf_path, output_file, quant_type):
|
||||
return # error already sent
|
||||
|
||||
# Clean up f16 intermediate file to save disk space
|
||||
try:
|
||||
os.remove(f16_gguf_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
self._send_progress(
|
||||
job, "completed",
|
||||
f"Quantization complete: {quant_type}",
|
||||
progress_percent=100.0,
|
||||
output_file=output_file,
|
||||
extra_metrics=self._file_metrics(output_file),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
|
||||
def _resolve_model(self, job, model, output_dir, extra_options):
|
||||
"""Download model from HuggingFace or return local path."""
|
||||
# If it's a local path that exists, use it directly
|
||||
if os.path.isdir(model):
|
||||
return model
|
||||
|
||||
# If it looks like a GGUF file path, use it directly
|
||||
if os.path.isfile(model) and model.endswith(".gguf"):
|
||||
return model
|
||||
|
||||
# Download from HuggingFace
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
hf_token = extra_options.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||
cache_dir = os.path.join(output_dir, "hf_cache")
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloading {model} from HuggingFace...", progress_percent=5.0)
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=model,
|
||||
cache_dir=cache_dir,
|
||||
token=hf_token,
|
||||
ignore_patterns=["*.md", "*.txt", "LICENSE*", ".gitattributes"],
|
||||
)
|
||||
|
||||
self._send_progress(job, "downloading", f"Downloaded {model}", progress_percent=25.0)
|
||||
return local_path
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "gated" in error_msg.lower() or "access" in error_msg.lower():
|
||||
self._send_progress(
|
||||
job, "failed",
|
||||
f"Access denied for {model}. This model may be gated — "
|
||||
f"please accept the license at https://huggingface.co/{model} "
|
||||
f"and provide your HF token in extra_options.",
|
||||
)
|
||||
else:
|
||||
self._send_progress(job, "failed", f"Failed to download model: {error_msg}")
|
||||
return None
|
||||
|
||||
def _convert_to_gguf(self, job, model_path, output_path, extra_options):
|
||||
"""Convert HF model to f16 GGUF using convert_hf_to_gguf.py."""
|
||||
# If the model_path is already a GGUF file, just use it as-is
|
||||
if isinstance(model_path, str) and model_path.endswith(".gguf"):
|
||||
# Copy or symlink the GGUF file
|
||||
import shutil
|
||||
shutil.copy2(model_path, output_path)
|
||||
return True
|
||||
|
||||
# Find convert_hf_to_gguf.py
|
||||
convert_script = self._find_convert_script()
|
||||
if convert_script is None:
|
||||
self._send_progress(job, "failed", "convert_hf_to_gguf.py not found. Install it via the backend's install.sh.")
|
||||
return False
|
||||
|
||||
cmd = [
|
||||
sys.executable, convert_script,
|
||||
model_path,
|
||||
"--outfile", output_path,
|
||||
"--outtype", "f16",
|
||||
]
|
||||
|
||||
self._send_progress(job, "converting", "Running convert_hf_to_gguf.py...", progress_percent=35.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._send_progress(job, "converting", line, progress_percent=40.0)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during conversion")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"convert_hf_to_gguf.py failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Conversion failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _quantize(self, job, input_path, output_path, quant_type):
|
||||
"""Quantize a GGUF file using llama-quantize."""
|
||||
quantize_bin = self._find_quantize_binary()
|
||||
if quantize_bin is None:
|
||||
self._send_progress(job, "failed", "llama-quantize binary not found. Ensure it is installed and in PATH.")
|
||||
return False
|
||||
|
||||
cmd = [quantize_bin, input_path, output_path, quant_type]
|
||||
|
||||
self._send_progress(job, "quantizing", f"Running llama-quantize ({quant_type})...", progress_percent=55.0)
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
job.process = process
|
||||
|
||||
for line in process.stdout:
|
||||
line = line.strip()
|
||||
if line:
|
||||
# Try to parse progress from llama-quantize output
|
||||
progress = self._parse_quantize_progress(line)
|
||||
pct = 55.0 + (progress * 0.40) if progress else 60.0
|
||||
self._send_progress(job, "quantizing", line, progress_percent=pct)
|
||||
if job.stop_event.is_set():
|
||||
process.kill()
|
||||
self._send_progress(job, "stopped", "Job stopped during quantization")
|
||||
return False
|
||||
|
||||
process.wait()
|
||||
job.process = None
|
||||
|
||||
if process.returncode != 0:
|
||||
self._send_progress(job, "failed", f"llama-quantize failed with exit code {process.returncode}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._send_progress(job, "failed", f"Quantization failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def _parse_quantize_progress(self, line):
|
||||
"""Try to parse a progress percentage from llama-quantize output."""
|
||||
# llama-quantize typically outputs lines like:
|
||||
# [ 123/ 1234] quantizing blk.0.attn_k.weight ...
|
||||
match = re.search(r'\[\s*(\d+)\s*/\s*(\d+)\]', line)
|
||||
if match:
|
||||
current = int(match.group(1))
|
||||
total = int(match.group(2))
|
||||
if total > 0:
|
||||
return current / total
|
||||
return None
|
||||
|
||||
def _find_convert_script(self):
|
||||
"""Find convert_hf_to_gguf.py in known locations."""
|
||||
candidates = [
|
||||
# Same directory as this backend
|
||||
os.path.join(os.path.dirname(__file__), "convert_hf_to_gguf.py"),
|
||||
# Installed via install.sh
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "convert_hf_to_gguf.py"),
|
||||
]
|
||||
|
||||
# Also check if it's on PATH
|
||||
import shutil
|
||||
path_script = shutil.which("convert_hf_to_gguf.py")
|
||||
if path_script:
|
||||
candidates.append(path_script)
|
||||
|
||||
for candidate in candidates:
|
||||
if os.path.isfile(candidate):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def _find_quantize_binary(self):
|
||||
"""Find llama-quantize binary."""
|
||||
import shutil
|
||||
|
||||
# Check common names on PATH
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
path = shutil.which(name)
|
||||
if path:
|
||||
return path
|
||||
|
||||
# Check in the backend directory (built by install.sh)
|
||||
backend_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
for name in ["llama-quantize", "quantize"]:
|
||||
candidate = os.path.join(backend_dir, name)
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def _file_metrics(self, filepath):
|
||||
"""Return file size metrics."""
|
||||
try:
|
||||
size_bytes = os.path.getsize(filepath)
|
||||
return {"file_size_mb": size_bytes / (1024 * 1024)}
|
||||
except OSError:
|
||||
return {}
|
||||
|
||||
def QuantizationProgress(self, request, context):
|
||||
job_id = request.job_id
|
||||
job = self.jobs.get(job_id)
|
||||
if job is None:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
update = job.progress_queue.get(timeout=1.0)
|
||||
yield update
|
||||
# If this is a terminal status, stop streaming
|
||||
if update.status in ("completed", "failed", "stopped"):
|
||||
break
|
||||
except queue.Empty:
|
||||
# Check if the thread is still alive
|
||||
if job.thread and not job.thread.is_alive():
|
||||
# Thread finished but no terminal update — drain queue
|
||||
while not job.progress_queue.empty():
|
||||
update = job.progress_queue.get_nowait()
|
||||
yield update
|
||||
break
|
||||
# Check if client disconnected
|
||||
if context.is_active() is False:
|
||||
break
|
||||
|
||||
def StopQuantization(self, request, context):
|
||||
job_id = request.job_id
|
||||
job = self.jobs.get(job_id)
|
||||
if job is None:
|
||||
return backend_pb2.Result(success=False, message=f"Job {job_id} not found")
|
||||
|
||||
job.stop_event.set()
|
||||
if job.process:
|
||||
try:
|
||||
job.process.kill()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return backend_pb2.Result(success=True, message="Stop signal sent")
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print(f"Quantization backend listening on {address}", file=sys.stderr, flush=True)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="llama.cpp quantization gRPC backend")
|
||||
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||
args = parser.parse_args()
|
||||
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
|
||||
serve(args.addr)
|
||||
58
backend/python/llama-cpp-quantization/install.sh
Executable file
58
backend/python/llama-cpp-quantization/install.sh
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade "
|
||||
installRequirements
|
||||
|
||||
# Fetch convert_hf_to_gguf.py from llama.cpp
|
||||
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
||||
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
curl -L --fail --retry 3 \
|
||||
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
||||
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py."
|
||||
fi
|
||||
|
||||
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
|
||||
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
if [ "x${USE_PIP:-}" == "xtrue" ]; then
|
||||
pip install "${GGUF_PIP_SPEC}" || {
|
||||
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||
pip install "gguf>=0.16.0"
|
||||
}
|
||||
else
|
||||
uv pip install "${GGUF_PIP_SPEC}" || {
|
||||
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||
uv pip install "gguf>=0.16.0"
|
||||
}
|
||||
fi
|
||||
|
||||
# Build llama-quantize from llama.cpp if not already present
|
||||
QUANTIZE_BIN="${EDIR}/llama-quantize"
|
||||
if [ ! -x "${QUANTIZE_BIN}" ] && ! command -v llama-quantize &>/dev/null; then
|
||||
if command -v cmake &>/dev/null; then
|
||||
echo "Building llama-quantize from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||
LLAMA_CPP_SRC="${EDIR}/llama.cpp"
|
||||
if [ ! -d "${LLAMA_CPP_SRC}" ]; then
|
||||
git clone --depth 1 --branch "${LLAMA_CPP_CONVERT_VERSION}" \
|
||||
https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}" 2>/dev/null || \
|
||||
git clone --depth 1 https://github.com/ggml-org/llama.cpp.git "${LLAMA_CPP_SRC}"
|
||||
fi
|
||||
cmake -B "${LLAMA_CPP_SRC}/build" -S "${LLAMA_CPP_SRC}" -DGGML_NATIVE=OFF -DBUILD_SHARED_LIBS=OFF
|
||||
cmake --build "${LLAMA_CPP_SRC}/build" --target llama-quantize -j"$(nproc 2>/dev/null || echo 2)"
|
||||
cp "${LLAMA_CPP_SRC}/build/bin/llama-quantize" "${QUANTIZE_BIN}"
|
||||
chmod +x "${QUANTIZE_BIN}"
|
||||
echo "Built llama-quantize at ${QUANTIZE_BIN}"
|
||||
else
|
||||
echo "Warning: cmake not found — llama-quantize will not be available. Install cmake or provide llama-quantize on PATH."
|
||||
fi
|
||||
fi
|
||||
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
@@ -0,0 +1,4 @@
|
||||
torch==2.10.0
|
||||
transformers>=4.56.2
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
3
backend/python/llama-cpp-quantization/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
10
backend/python/llama-cpp-quantization/run.sh
Executable file
10
backend/python/llama-cpp-quantization/run.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
99
backend/python/llama-cpp-quantization/test.py
Normal file
99
backend/python/llama-cpp-quantization/test.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Test script for the llama-cpp-quantization gRPC backend.
|
||||
|
||||
Downloads a small model (functiongemma-270m-it), converts it to GGUF,
|
||||
and quantizes it to q4_k_m.
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
SERVER_ADDR = "localhost:50051"
|
||||
# Small model for CI testing (~540MB)
|
||||
TEST_MODEL = "unsloth/functiongemma-270m-it"
|
||||
|
||||
|
||||
class TestQuantizationBackend(unittest.TestCase):
|
||||
"""Tests for the llama-cpp-quantization gRPC service."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.service = subprocess.Popen(
|
||||
["python3", "backend.py", "--addr", SERVER_ADDR]
|
||||
)
|
||||
time.sleep(5)
|
||||
cls.output_dir = tempfile.mkdtemp(prefix="quantize-test-")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.service.kill()
|
||||
cls.service.wait()
|
||||
# Clean up output directory
|
||||
if os.path.isdir(cls.output_dir):
|
||||
shutil.rmtree(cls.output_dir, ignore_errors=True)
|
||||
|
||||
def _channel(self):
|
||||
return grpc.insecure_channel(SERVER_ADDR)
|
||||
|
||||
def test_01_health(self):
|
||||
"""Test that the server starts and responds to health checks."""
|
||||
with self._channel() as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b"OK")
|
||||
|
||||
def test_02_quantize_small_model(self):
|
||||
"""Download, convert, and quantize functiongemma-270m-it to q4_k_m."""
|
||||
with self._channel() as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
|
||||
job_id = "test-quantize-001"
|
||||
|
||||
# Start quantization
|
||||
result = stub.StartQuantization(
|
||||
backend_pb2.QuantizationRequest(
|
||||
model=TEST_MODEL,
|
||||
quantization_type="q4_k_m",
|
||||
output_dir=self.output_dir,
|
||||
job_id=job_id,
|
||||
)
|
||||
)
|
||||
self.assertTrue(result.success, f"StartQuantization failed: {result.message}")
|
||||
self.assertEqual(result.job_id, job_id)
|
||||
|
||||
# Stream progress until completion
|
||||
final_status = None
|
||||
output_file = None
|
||||
for update in stub.QuantizationProgress(
|
||||
backend_pb2.QuantizationProgressRequest(job_id=job_id)
|
||||
):
|
||||
print(f" [{update.status}] {update.progress_percent:.1f}% - {update.message}")
|
||||
final_status = update.status
|
||||
if update.output_file:
|
||||
output_file = update.output_file
|
||||
|
||||
self.assertEqual(final_status, "completed", f"Expected completed, got {final_status}")
|
||||
self.assertIsNotNone(output_file, "No output_file in progress updates")
|
||||
self.assertTrue(os.path.isfile(output_file), f"Output file not found: {output_file}")
|
||||
|
||||
# Verify the output is a valid GGUF file (starts with "GGUF" magic)
|
||||
with open(output_file, "rb") as f:
|
||||
magic = f.read(4)
|
||||
self.assertEqual(magic, b"GGUF", f"Output file does not have GGUF magic: {magic!r}")
|
||||
|
||||
# Verify reasonable file size (q4_k_m of 270M model should be ~150-400MB)
|
||||
size_mb = os.path.getsize(output_file) / (1024 * 1024)
|
||||
print(f" Output file size: {size_mb:.1f} MB")
|
||||
self.assertGreater(size_mb, 10, "Output file suspiciously small")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
11
backend/python/llama-cpp-quantization/test.sh
Executable file
11
backend/python/llama-cpp-quantization/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
@@ -121,9 +121,6 @@ type RunCMD struct {
|
||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||
|
||||
// Fine-tuning
|
||||
EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`
|
||||
|
||||
// Authentication
|
||||
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
||||
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
|
||||
@@ -329,11 +326,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
|
||||
}
|
||||
|
||||
// Fine-tuning
|
||||
if r.EnableFineTuning {
|
||||
opts = append(opts, config.EnableFineTuning)
|
||||
}
|
||||
|
||||
// Authentication
|
||||
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
|
||||
if authEnabled {
|
||||
|
||||
@@ -97,9 +97,6 @@ type ApplicationConfig struct {
|
||||
// Agent Pool (LocalAGI integration)
|
||||
AgentPool AgentPoolConfig
|
||||
|
||||
// Fine-tuning
|
||||
FineTuning FineTuningConfig
|
||||
|
||||
// Authentication & Authorization
|
||||
Auth AuthConfig
|
||||
}
|
||||
@@ -145,11 +142,6 @@ type AgentPoolConfig struct {
|
||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||
}
|
||||
|
||||
// FineTuningConfig holds configuration for fine-tuning support.
|
||||
type FineTuningConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type AppOption func(*ApplicationConfig)
|
||||
|
||||
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||
@@ -741,12 +733,6 @@ func WithAgentHubURL(url string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// Fine-tuning options
|
||||
|
||||
var EnableFineTuning = func(o *ApplicationConfig) {
|
||||
o.FineTuning.Enabled = true
|
||||
}
|
||||
|
||||
// Auth options
|
||||
|
||||
func WithAuthEnabled(enabled bool) AppOption {
|
||||
|
||||
@@ -304,15 +304,22 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
|
||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||
// Fine-tuning routes
|
||||
if application.ApplicationConfig().FineTuning.Enabled {
|
||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||
ftService := services.NewFineTuneService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||
}
|
||||
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||
ftService := services.NewFineTuneService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||
|
||||
// Quantization routes
|
||||
quantizationMw := auth.RequireFeature(application.AuthDB(), auth.FeatureQuantization)
|
||||
qService := services.NewQuantizationService(
|
||||
application.ApplicationConfig(),
|
||||
application.ModelLoader(),
|
||||
application.ModelConfigLoader(),
|
||||
)
|
||||
routes.RegisterQuantizationRoutes(e, qService, application.ApplicationConfig(), quantizationMw)
|
||||
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
|
||||
@@ -97,6 +97,16 @@ var RouteFeatureRegistry = []RouteFeature{
|
||||
{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
|
||||
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
|
||||
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
|
||||
|
||||
// Quantization
|
||||
{"POST", "/api/quantization/jobs", FeatureQuantization},
|
||||
{"GET", "/api/quantization/jobs", FeatureQuantization},
|
||||
{"GET", "/api/quantization/jobs/:id", FeatureQuantization},
|
||||
{"POST", "/api/quantization/jobs/:id/stop", FeatureQuantization},
|
||||
{"DELETE", "/api/quantization/jobs/:id", FeatureQuantization},
|
||||
{"GET", "/api/quantization/jobs/:id/progress", FeatureQuantization},
|
||||
{"POST", "/api/quantization/jobs/:id/import", FeatureQuantization},
|
||||
{"GET", "/api/quantization/jobs/:id/download", FeatureQuantization},
|
||||
}
|
||||
|
||||
// FeatureMeta describes a feature for the admin API/UI.
|
||||
@@ -120,6 +130,7 @@ func AgentFeatureMetas() []FeatureMeta {
|
||||
func GeneralFeatureMetas() []FeatureMeta {
|
||||
return []FeatureMeta{
|
||||
{FeatureFineTuning, "Fine-Tuning", false},
|
||||
{FeatureQuantization, "Quantization", false},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,8 @@ const (
|
||||
FeatureMCPJobs = "mcp_jobs"
|
||||
|
||||
// General features (default OFF for new users)
|
||||
FeatureFineTuning = "fine_tuning"
|
||||
FeatureFineTuning = "fine_tuning"
|
||||
FeatureQuantization = "quantization"
|
||||
|
||||
// API features (default ON for new users)
|
||||
FeatureChat = "chat"
|
||||
@@ -56,7 +57,7 @@ const (
|
||||
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}
|
||||
|
||||
// GeneralFeatures lists general features (default OFF).
|
||||
var GeneralFeatures = []string{FeatureFineTuning}
|
||||
var GeneralFeatures = []string{FeatureFineTuning, FeatureQuantization}
|
||||
|
||||
// APIFeatures lists API endpoint features (default ON).
|
||||
var APIFeatures = []string{
|
||||
|
||||
243
core/http/endpoints/localai/quantization.go
Normal file
243
core/http/endpoints/localai/quantization.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// StartQuantizationJobEndpoint starts a new quantization job.
|
||||
func StartQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
|
||||
var req schema.QuantizationJobRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "model is required",
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := qService.StartJob(c.Request().Context(), userID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusCreated, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ListQuantizationJobsEndpoint lists quantization jobs for the current user.
|
||||
func ListQuantizationJobsEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobs := qService.ListJobs(userID)
|
||||
if jobs == nil {
|
||||
jobs = []*schema.QuantizationJob{}
|
||||
}
|
||||
return c.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
}
|
||||
|
||||
// GetQuantizationJobEndpoint gets a specific quantization job.
|
||||
func GetQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
job, err := qService.GetJob(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, job)
|
||||
}
|
||||
}
|
||||
|
||||
// StopQuantizationJobEndpoint stops a running quantization job.
|
||||
func StopQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
err := qService.StopJob(c.Request().Context(), userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"status": "stopped",
|
||||
"message": "Quantization job stopped",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteQuantizationJobEndpoint deletes a quantization job and its data.
|
||||
func DeleteQuantizationJobEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
err := qService.DeleteJob(userID, jobID)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
} else if strings.Contains(err.Error(), "cannot delete") {
|
||||
status = http.StatusConflict
|
||||
}
|
||||
return c.JSON(status, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"status": "deleted",
|
||||
"message": "Quantization job deleted",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// QuantizationProgressEndpoint streams progress updates via SSE.
|
||||
func QuantizationProgressEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
// Set SSE headers
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
err := qService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.QuantizationProgressEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(c.Response(), "data: %s\n\n", data)
|
||||
c.Response().Flush()
|
||||
})
|
||||
if err != nil {
|
||||
// If headers already sent, we can't send a JSON error
|
||||
fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
|
||||
c.Response().Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ImportQuantizedModelEndpoint imports a quantized model into LocalAI.
|
||||
func ImportQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
var req schema.QuantizationImportRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
modelName, err := qService.ImportModel(c.Request().Context(), userID, jobID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusAccepted, map[string]string{
|
||||
"status": "importing",
|
||||
"message": "Import started for model '" + modelName + "'",
|
||||
"model_name": modelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadQuantizedModelEndpoint streams the quantized model file.
|
||||
func DownloadQuantizedModelEndpoint(qService *services.QuantizationService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
outputPath, downloadName, err := qService.GetOutputPath(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.Attachment(outputPath, downloadName)
|
||||
}
|
||||
}
|
||||
|
||||
// ListQuantizationBackendsEndpoint returns installed backends tagged with "quantization".
|
||||
func ListQuantizationBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to list backends: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
type backendInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
var result []backendInfo
|
||||
for _, b := range backends {
|
||||
if !b.Installed {
|
||||
continue
|
||||
}
|
||||
hasTag := false
|
||||
for _, t := range b.Tags {
|
||||
if strings.EqualFold(t, "quantization") {
|
||||
hasTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTag {
|
||||
continue
|
||||
}
|
||||
name := b.Name
|
||||
if b.Alias != "" {
|
||||
name = b.Alias
|
||||
}
|
||||
result = append(result, backendInfo{
|
||||
Name: name,
|
||||
Description: b.Description,
|
||||
Tags: b.Tags,
|
||||
})
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = []backendInfo{}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,8 @@ const sections = [
|
||||
id: 'tools',
|
||||
title: 'Tools',
|
||||
items: [
|
||||
{ path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' },
|
||||
{ path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune (Experimental)', feature: 'fine_tuning' },
|
||||
{ path: '/app/quantize', icon: 'fas fa-compress', label: 'Quantize (Experimental)', feature: 'quantization' },
|
||||
],
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1042,7 +1042,7 @@ export default function FineTune() {
|
||||
<div className="page">
|
||||
<div className="page-header" style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'flex-start' }}>
|
||||
<div>
|
||||
<h1 className="page-title">Fine-Tuning</h1>
|
||||
<h1 className="page-title">Fine-Tuning <span className="badge badge-warning" style={{ fontSize: '0.45em', verticalAlign: 'middle' }}>Experimental</span></h1>
|
||||
<p className="page-subtitle">Create and manage fine-tuning jobs</p>
|
||||
</div>
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)' }}>
|
||||
|
||||
485
core/http/react-ui/src/pages/Quantize.jsx
Normal file
485
core/http/react-ui/src/pages/Quantize.jsx
Normal file
@@ -0,0 +1,485 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { quantizationApi } from '../utils/api'
|
||||
|
||||
const QUANT_PRESETS = [
|
||||
'q2_k', 'q3_k_s', 'q3_k_m', 'q3_k_l',
|
||||
'q4_0', 'q4_k_s', 'q4_k_m',
|
||||
'q5_0', 'q5_k_s', 'q5_k_m',
|
||||
'q6_k', 'q8_0', 'f16',
|
||||
]
|
||||
const DEFAULT_QUANT = 'q4_k_m'
|
||||
const FALLBACK_BACKENDS = ['llama-cpp-quantization']
|
||||
|
||||
const statusBadgeClass = {
|
||||
queued: '', downloading: 'badge-warning', converting: 'badge-warning',
|
||||
quantizing: 'badge-info', completed: 'badge-success',
|
||||
failed: 'badge-error', stopped: '',
|
||||
}
|
||||
|
||||
// ── Reusable sub-components ──────────────────────────────────────
|
||||
|
||||
function FormSection({ icon, title, children }) {
|
||||
return (
|
||||
<div style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<h4 style={{ marginBottom: 'var(--spacing-sm)', display: 'flex', alignItems: 'center', gap: '0.5em' }}>
|
||||
{icon && <i className={icon} />} {title}
|
||||
</h4>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function ProgressMonitor({ job, onClose }) {
|
||||
const [events, setEvents] = useState([])
|
||||
const [latestEvent, setLatestEvent] = useState(null)
|
||||
const esRef = useRef(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (!job) return
|
||||
const terminal = ['completed', 'failed', 'stopped']
|
||||
if (terminal.includes(job.status)) return
|
||||
|
||||
const es = new EventSource(quantizationApi.progressUrl(job.id))
|
||||
esRef.current = es
|
||||
es.onmessage = (e) => {
|
||||
try {
|
||||
const data = JSON.parse(e.data)
|
||||
setLatestEvent(data)
|
||||
setEvents(prev => [...prev.slice(-100), data])
|
||||
if (terminal.includes(data.status)) {
|
||||
es.close()
|
||||
}
|
||||
} catch { /* ignore parse errors */ }
|
||||
}
|
||||
es.onerror = () => { es.close() }
|
||||
return () => { es.close() }
|
||||
}, [job?.id])
|
||||
|
||||
if (!job) return null
|
||||
|
||||
const progress = latestEvent?.progress_percent ?? 0
|
||||
const status = latestEvent?.status ?? job.status
|
||||
const message = latestEvent?.message ?? job.message ?? ''
|
||||
|
||||
return (
|
||||
<div className="card" style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 'var(--spacing-sm)' }}>
|
||||
<h4 style={{ margin: 0 }}>
|
||||
<i className="fas fa-chart-line" style={{ marginRight: '0.5em' }} />
|
||||
Progress: {job.model}
|
||||
</h4>
|
||||
<button className="btn btn-sm" onClick={onClose}><i className="fas fa-times" /></button>
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: 'var(--spacing-sm)' }}>
|
||||
<span className={`badge ${statusBadgeClass[status] || ''}`}>{status}</span>
|
||||
<span style={{ marginLeft: '1em', fontSize: '0.9em', color: 'var(--color-text-secondary)' }}>{message}</span>
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
width: '100%', height: '24px', borderRadius: '12px',
|
||||
background: 'var(--color-bg-tertiary)', overflow: 'hidden', marginBottom: 'var(--spacing-sm)',
|
||||
}}>
|
||||
<div style={{
|
||||
width: `${Math.min(progress, 100)}%`, height: '100%',
|
||||
background: status === 'failed' ? 'var(--color-error)' : 'var(--color-primary)',
|
||||
transition: 'width 0.3s ease', borderRadius: '12px',
|
||||
display: 'flex', alignItems: 'center', justifyContent: 'center',
|
||||
fontSize: '0.75em', fontWeight: 'bold', color: '#fff',
|
||||
}}>
|
||||
{progress > 8 ? `${progress.toFixed(1)}%` : ''}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Log tail */}
|
||||
<div style={{
|
||||
maxHeight: '150px', overflow: 'auto', fontSize: '0.8em',
|
||||
background: 'var(--color-bg-secondary)', borderRadius: '6px', padding: '0.5em',
|
||||
fontFamily: 'monospace',
|
||||
}}>
|
||||
{events.slice(-20).map((ev, i) => (
|
||||
<div key={i} style={{ color: ev.status === 'failed' ? 'var(--color-error)' : 'var(--color-text-secondary)' }}>
|
||||
[{ev.status}] {ev.message}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function ImportPanel({ job, onRefresh }) {
|
||||
const [modelName, setModelName] = useState('')
|
||||
const [importing, setImporting] = useState(false)
|
||||
const [error, setError] = useState('')
|
||||
const pollRef = useRef(null)
|
||||
|
||||
// Poll for import status
|
||||
useEffect(() => {
|
||||
if (job?.import_status !== 'importing') return
|
||||
pollRef.current = setInterval(async () => {
|
||||
try {
|
||||
await onRefresh()
|
||||
} catch { /* ignore */ }
|
||||
}, 3000)
|
||||
return () => clearInterval(pollRef.current)
|
||||
}, [job?.import_status, onRefresh])
|
||||
|
||||
if (!job || job.status !== 'completed') return null
|
||||
|
||||
const handleImport = async () => {
|
||||
setImporting(true)
|
||||
setError('')
|
||||
try {
|
||||
await quantizationApi.importModel(job.id, { name: modelName || undefined })
|
||||
await onRefresh()
|
||||
} catch (e) {
|
||||
setError(e.message || 'Import failed')
|
||||
} finally {
|
||||
setImporting(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card" style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<h4 style={{ marginBottom: 'var(--spacing-sm)' }}>
|
||||
<i className="fas fa-file-export" style={{ marginRight: '0.5em' }} />
|
||||
Output
|
||||
</h4>
|
||||
|
||||
{error && (
|
||||
<div className="alert alert-error" style={{ marginBottom: 'var(--spacing-sm)' }}>{error}</div>
|
||||
)}
|
||||
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)', flexWrap: 'wrap', alignItems: 'flex-end' }}>
|
||||
{/* Download */}
|
||||
<a
|
||||
href={quantizationApi.downloadUrl(job.id)}
|
||||
className="btn btn-secondary"
|
||||
download
|
||||
>
|
||||
<i className="fas fa-download" style={{ marginRight: '0.4em' }} />
|
||||
Download GGUF
|
||||
</a>
|
||||
|
||||
{/* Import */}
|
||||
{job.import_status === 'completed' ? (
|
||||
<a href={`/app/chat/${encodeURIComponent(job.import_model_name)}`} className="btn btn-success">
|
||||
<i className="fas fa-comments" style={{ marginRight: '0.4em' }} />
|
||||
Chat with {job.import_model_name}
|
||||
</a>
|
||||
) : job.import_status === 'importing' ? (
|
||||
<button className="btn" disabled>
|
||||
<i className="fas fa-spinner fa-spin" style={{ marginRight: '0.4em' }} />
|
||||
Importing... {job.import_message}
|
||||
</button>
|
||||
) : (
|
||||
<>
|
||||
<input
|
||||
className="input"
|
||||
placeholder="Model name (auto-generated if empty)"
|
||||
value={modelName}
|
||||
onChange={e => setModelName(e.target.value)}
|
||||
style={{ maxWidth: '280px' }}
|
||||
/>
|
||||
<button className="btn btn-primary" onClick={handleImport} disabled={importing}>
|
||||
<i className="fas fa-file-import" style={{ marginRight: '0.4em' }} />
|
||||
Import to LocalAI
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{job.import_status === 'failed' && (
|
||||
<div className="alert alert-error" style={{ marginTop: 'var(--spacing-sm)' }}>
|
||||
Import failed: {job.import_message}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ── Main page ────────────────────────────────────────────────────
|
||||
|
||||
export default function Quantize() {
|
||||
// Form state
|
||||
const [model, setModel] = useState('')
|
||||
const [quantType, setQuantType] = useState(DEFAULT_QUANT)
|
||||
const [customQuantType, setCustomQuantType] = useState('')
|
||||
const [useCustomQuant, setUseCustomQuant] = useState(false)
|
||||
const [backend, setBackend] = useState('')
|
||||
const [hfToken, setHfToken] = useState('')
|
||||
const [backends, setBackends] = useState([])
|
||||
|
||||
// Jobs state
|
||||
const [jobs, setJobs] = useState([])
|
||||
const [selectedJob, setSelectedJob] = useState(null)
|
||||
const [error, setError] = useState('')
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
// Load backends and jobs
|
||||
const loadJobs = useCallback(async () => {
|
||||
try {
|
||||
const data = await quantizationApi.listJobs()
|
||||
setJobs(data)
|
||||
// Refresh selected job if it exists
|
||||
if (selectedJob) {
|
||||
const updated = data.find(j => j.id === selectedJob.id)
|
||||
if (updated) setSelectedJob(updated)
|
||||
}
|
||||
} catch { /* ignore */ }
|
||||
}, [selectedJob?.id])
|
||||
|
||||
useEffect(() => {
|
||||
quantizationApi.listBackends().then(b => {
|
||||
setBackends(b.length ? b : FALLBACK_BACKENDS.map(name => ({ name })))
|
||||
if (b.length) setBackend(b[0].name)
|
||||
else setBackend(FALLBACK_BACKENDS[0])
|
||||
}).catch(() => {
|
||||
setBackends(FALLBACK_BACKENDS.map(name => ({ name })))
|
||||
setBackend(FALLBACK_BACKENDS[0])
|
||||
})
|
||||
loadJobs()
|
||||
const interval = setInterval(loadJobs, 10000)
|
||||
return () => clearInterval(interval)
|
||||
}, [])
|
||||
|
||||
const handleSubmit = async (e) => {
|
||||
e.preventDefault()
|
||||
setError('')
|
||||
setSubmitting(true)
|
||||
try {
|
||||
const req = {
|
||||
model,
|
||||
backend: backend || FALLBACK_BACKENDS[0],
|
||||
quantization_type: useCustomQuant ? customQuantType : quantType,
|
||||
}
|
||||
if (hfToken) {
|
||||
req.extra_options = { hf_token: hfToken }
|
||||
}
|
||||
|
||||
const resp = await quantizationApi.startJob(req)
|
||||
setModel('')
|
||||
await loadJobs()
|
||||
|
||||
// Select the new job
|
||||
const refreshed = await quantizationApi.getJob(resp.id)
|
||||
setSelectedJob(refreshed)
|
||||
} catch (err) {
|
||||
setError(err.message || 'Failed to start job')
|
||||
} finally {
|
||||
setSubmitting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleStop = async (jobId) => {
|
||||
try {
|
||||
await quantizationApi.stopJob(jobId)
|
||||
await loadJobs()
|
||||
} catch (err) {
|
||||
setError(err.message || 'Failed to stop job')
|
||||
}
|
||||
}
|
||||
|
||||
const handleDelete = async (jobId) => {
|
||||
try {
|
||||
await quantizationApi.deleteJob(jobId)
|
||||
if (selectedJob?.id === jobId) setSelectedJob(null)
|
||||
await loadJobs()
|
||||
} catch (err) {
|
||||
setError(err.message || 'Failed to delete job')
|
||||
}
|
||||
}
|
||||
|
||||
const effectiveQuantType = useCustomQuant ? customQuantType : quantType
|
||||
|
||||
return (
|
||||
<div style={{ maxWidth: '900px', margin: '0 auto', padding: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<h2 style={{ margin: 0 }}>
|
||||
<i className="fas fa-compress" style={{ marginRight: '0.4em' }} />
|
||||
Model Quantization
|
||||
</h2>
|
||||
<span className="badge badge-warning" style={{ fontSize: '0.7em' }}>Experimental</span>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="card" style={{ borderColor: 'var(--color-error)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ color: 'var(--color-error)' }}><i className="fas fa-exclamation-triangle" /> {error}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* ── New Job Form ── */}
|
||||
<form onSubmit={handleSubmit}>
|
||||
<div className="card" style={{ marginBottom: 'var(--spacing-md)' }}>
|
||||
<FormSection icon="fas fa-cube" title="Model">
|
||||
<input
|
||||
className="input"
|
||||
placeholder="HuggingFace model name (e.g. meta-llama/Llama-3.2-1B) or local path"
|
||||
value={model}
|
||||
onChange={e => setModel(e.target.value)}
|
||||
required
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
</FormSection>
|
||||
|
||||
<FormSection icon="fas fa-sliders-h" title="Quantization Type">
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)', alignItems: 'center', flexWrap: 'wrap' }}>
|
||||
<select
|
||||
className="input"
|
||||
value={useCustomQuant ? '__custom__' : quantType}
|
||||
onChange={e => {
|
||||
if (e.target.value === '__custom__') {
|
||||
setUseCustomQuant(true)
|
||||
} else {
|
||||
setUseCustomQuant(false)
|
||||
setQuantType(e.target.value)
|
||||
}
|
||||
}}
|
||||
style={{ maxWidth: '200px' }}
|
||||
>
|
||||
{QUANT_PRESETS.map(q => (
|
||||
<option key={q} value={q}>{q}</option>
|
||||
))}
|
||||
<option value="__custom__">Custom...</option>
|
||||
</select>
|
||||
{useCustomQuant && (
|
||||
<input
|
||||
className="input"
|
||||
placeholder="Enter custom quantization type"
|
||||
value={customQuantType}
|
||||
onChange={e => setCustomQuantType(e.target.value)}
|
||||
required
|
||||
style={{ maxWidth: '220px' }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</FormSection>
|
||||
|
||||
<FormSection icon="fas fa-server" title="Backend">
|
||||
<select
|
||||
className="input"
|
||||
value={backend}
|
||||
onChange={e => setBackend(e.target.value)}
|
||||
style={{ maxWidth: '280px' }}
|
||||
>
|
||||
{backends.map(b => (
|
||||
<option key={b.name || b} value={b.name || b}>{b.name || b}</option>
|
||||
))}
|
||||
</select>
|
||||
</FormSection>
|
||||
|
||||
<FormSection icon="fas fa-key" title="HuggingFace Token (optional)">
|
||||
<input
|
||||
className="input"
|
||||
type="password"
|
||||
placeholder="hf_... (required for gated models)"
|
||||
value={hfToken}
|
||||
onChange={e => setHfToken(e.target.value)}
|
||||
style={{ maxWidth: '400px' }}
|
||||
/>
|
||||
</FormSection>
|
||||
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
type="submit"
|
||||
disabled={submitting || !model || (useCustomQuant && !customQuantType)}
|
||||
>
|
||||
{submitting ? (
|
||||
<><i className="fas fa-spinner fa-spin" style={{ marginRight: '0.4em' }} /> Starting...</>
|
||||
) : (
|
||||
<><i className="fas fa-play" style={{ marginRight: '0.4em' }} /> Quantize ({effectiveQuantType})</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
{/* ── Selected Job Detail ── */}
|
||||
{selectedJob && (
|
||||
<>
|
||||
<ProgressMonitor
|
||||
job={selectedJob}
|
||||
onClose={() => setSelectedJob(null)}
|
||||
/>
|
||||
<ImportPanel
|
||||
job={selectedJob}
|
||||
onRefresh={async () => {
|
||||
const updated = await quantizationApi.getJob(selectedJob.id)
|
||||
setSelectedJob(updated)
|
||||
await loadJobs()
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* ── Jobs List ── */}
|
||||
{jobs.length > 0 && (
|
||||
<div className="card">
|
||||
<h4 style={{ marginBottom: 'var(--spacing-sm)' }}>
|
||||
<i className="fas fa-list" style={{ marginRight: '0.5em' }} />
|
||||
Jobs
|
||||
</h4>
|
||||
<div style={{ overflowX: 'auto' }}>
|
||||
<table style={{ width: '100%', borderCollapse: 'collapse', fontSize: '0.9em' }}>
|
||||
<thead>
|
||||
<tr style={{ borderBottom: '1px solid var(--color-border)' }}>
|
||||
<th style={{ textAlign: 'left', padding: '0.5em' }}>Model</th>
|
||||
<th style={{ textAlign: 'left', padding: '0.5em' }}>Quant</th>
|
||||
<th style={{ textAlign: 'left', padding: '0.5em' }}>Status</th>
|
||||
<th style={{ textAlign: 'left', padding: '0.5em' }}>Created</th>
|
||||
<th style={{ textAlign: 'right', padding: '0.5em' }}>Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{jobs.map(job => {
|
||||
const isActive = ['queued', 'downloading', 'converting', 'quantizing'].includes(job.status)
|
||||
const isSelected = selectedJob?.id === job.id
|
||||
return (
|
||||
<tr
|
||||
key={job.id}
|
||||
style={{
|
||||
borderBottom: '1px solid var(--color-border)',
|
||||
background: isSelected ? 'var(--color-bg-secondary)' : undefined,
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
onClick={() => setSelectedJob(job)}
|
||||
>
|
||||
<td style={{ padding: '0.5em', maxWidth: '250px', overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>
|
||||
{job.model}
|
||||
</td>
|
||||
<td style={{ padding: '0.5em' }}>
|
||||
<code>{job.quantization_type}</code>
|
||||
</td>
|
||||
<td style={{ padding: '0.5em' }}>
|
||||
<span className={`badge ${statusBadgeClass[job.status] || ''}`}>{job.status}</span>
|
||||
{job.import_status === 'completed' && (
|
||||
<span className="badge badge-success" style={{ marginLeft: '0.3em' }}>imported</span>
|
||||
)}
|
||||
</td>
|
||||
<td style={{ padding: '0.5em', fontSize: '0.85em', color: 'var(--color-text-secondary)' }}>
|
||||
{new Date(job.created_at).toLocaleString()}
|
||||
</td>
|
||||
<td style={{ padding: '0.5em', textAlign: 'right' }}>
|
||||
<div style={{ display: 'flex', gap: '0.3em', justifyContent: 'flex-end' }} onClick={e => e.stopPropagation()}>
|
||||
{isActive && (
|
||||
<button className="btn btn-sm btn-error" onClick={() => handleStop(job.id)} title="Stop">
|
||||
<i className="fas fa-stop" />
|
||||
</button>
|
||||
)}
|
||||
{!isActive && (
|
||||
<button className="btn btn-sm" onClick={() => handleDelete(job.id)} title="Delete">
|
||||
<i className="fas fa-trash" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ import BackendLogs from './pages/BackendLogs'
|
||||
import Explorer from './pages/Explorer'
|
||||
import Login from './pages/Login'
|
||||
import FineTune from './pages/FineTune'
|
||||
import Quantize from './pages/Quantize'
|
||||
import Studio from './pages/Studio'
|
||||
import NotFound from './pages/NotFound'
|
||||
import Usage from './pages/Usage'
|
||||
@@ -95,6 +96,7 @@ const appChildren = [
|
||||
{ path: 'agent-jobs/tasks/:id/edit', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
|
||||
{ path: 'agent-jobs/jobs/:id', element: <Feature feature="mcp_jobs"><AgentJobDetails /></Feature> },
|
||||
{ path: 'fine-tune', element: <Feature feature="fine_tuning"><FineTune /></Feature> },
|
||||
{ path: 'quantize', element: <Feature feature="quantization"><Quantize /></Feature> },
|
||||
{ path: 'model-editor/:name', element: <Admin><ModelEditor /></Admin> },
|
||||
{ path: 'pipeline-editor', element: <Admin><PipelineEditor /></Admin> },
|
||||
{ path: 'pipeline-editor/:name', element: <Admin><PipelineEditor /></Admin> },
|
||||
|
||||
13
core/http/react-ui/src/utils/api.js
vendored
13
core/http/react-ui/src/utils/api.js
vendored
@@ -410,6 +410,19 @@ export const fineTuneApi = {
|
||||
downloadUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/download`),
|
||||
}
|
||||
|
||||
// Quantization API
// Thin client wrappers over the /api/quantization REST endpoints.
// `enc` URL-encodes the job ID; `apiUrl` resolves a path against the
// configured API base; fetchJSON/postJSON handle JSON (de)serialization.
export const quantizationApi = {
  // List available quantization backends and their supported types.
  listBackends: () => fetchJSON('/api/quantization/backends'),
  // Start a new quantization job; `data` matches QuantizationJobRequest.
  startJob: (data) => postJSON('/api/quantization/jobs', data),
  // List all quantization jobs visible to the current user.
  listJobs: () => fetchJSON('/api/quantization/jobs'),
  // Fetch a single job's current state by ID.
  getJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}`),
  // Stop a running job (POST with no body).
  stopJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}/stop`, { method: 'POST' }),
  // Delete a finished/stopped job and its on-disk artifacts.
  deleteJob: (id) => fetchJSON(`/api/quantization/jobs/${enc(id)}`, { method: 'DELETE' }),
  // Import the job's quantized output as a LocalAI model; `data` may carry { name }.
  importModel: (id, data) => postJSON(`/api/quantization/jobs/${enc(id)}/import`, data),
  // URL for the SSE progress stream (consumed by EventSource, not fetchJSON).
  progressUrl: (id) => apiUrl(`/api/quantization/jobs/${enc(id)}/progress`),
  // URL for downloading the resulting GGUF file.
  downloadUrl: (id) => apiUrl(`/api/quantization/jobs/${enc(id)}/download`),
}
|
||||
|
||||
// File to base64 helper
|
||||
export function fileToBase64(file) {
|
||||
return new Promise((resolve, reject) => {
|
||||
|
||||
@@ -134,9 +134,10 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
||||
|
||||
router.GET("/api/features", func(c echo.Context) error {
|
||||
return c.JSON(200, map[string]bool{
|
||||
"agents": appConfig.AgentPool.Enabled,
|
||||
"mcp": !appConfig.DisableMCP,
|
||||
"fine_tuning": appConfig.FineTuning.Enabled,
|
||||
"agents": appConfig.AgentPool.Enabled,
|
||||
"mcp": !appConfig.DisableMCP,
|
||||
"fine_tuning": true,
|
||||
"quantization": true,
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
40
core/http/routes/quantization.go
Normal file
40
core/http/routes/quantization.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// RegisterQuantizationRoutes registers quantization API routes.
//
// All routes are mounted under /api/quantization and wrapped with the
// caller-supplied quantizationMw (auth / feature gating). If no service
// instance is provided, nothing is registered and the endpoints simply
// do not exist.
func RegisterQuantizationRoutes(e *echo.Echo, qService *services.QuantizationService, appConfig *config.ApplicationConfig, quantizationMw echo.MiddlewareFunc) {
	if qService == nil {
		return
	}

	// Service readiness middleware
	// NOTE(review): this nil check is unreachable — the early return above
	// already guarantees qService is non-nil for every registered handler.
	// Kept as a defensive guard; consider removing it (together with the
	// net/http import, if then unused).
	readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if qService == nil {
				return c.JSON(http.StatusServiceUnavailable, map[string]string{
					"error": "quantization service is not available",
				})
			}
			return next(c)
		}
	}

	// Route group: readiness check first, then auth/feature middleware.
	q := e.Group("/api/quantization", readyMw, quantizationMw)
	q.GET("/backends", localai.ListQuantizationBackendsEndpoint(appConfig))
	q.POST("/jobs", localai.StartQuantizationJobEndpoint(qService))
	q.GET("/jobs", localai.ListQuantizationJobsEndpoint(qService))
	q.GET("/jobs/:id", localai.GetQuantizationJobEndpoint(qService))
	q.POST("/jobs/:id/stop", localai.StopQuantizationJobEndpoint(qService))
	q.DELETE("/jobs/:id", localai.DeleteQuantizationJobEndpoint(qService))
	q.GET("/jobs/:id/progress", localai.QuantizationProgressEndpoint(qService))
	q.POST("/jobs/:id/import", localai.ImportQuantizedModelEndpoint(qService))
	q.GET("/jobs/:id/download", localai.DownloadQuantizedModelEndpoint(qService))
}
|
||||
55
core/schema/quantization.go
Normal file
55
core/schema/quantization.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package schema
|
||||
|
||||
// QuantizationJobRequest is the REST API request to start a quantization job.
type QuantizationJobRequest struct {
	Model            string            `json:"model"`                       // HF model name or local path
	Backend          string            `json:"backend"`                     // backend name, e.g. "llama-cpp-quantization" (defaulted server-side if empty)
	QuantizationType string            `json:"quantization_type,omitempty"` // q4_k_m, q5_k_m, q8_0, f16, etc.
	ExtraOptions     map[string]string `json:"extra_options,omitempty"`     // backend-specific options, passed through verbatim
}

// QuantizationJob represents a quantization job with its current state.
// It is both the API representation and the on-disk persisted form
// (state.json in the job directory).
type QuantizationJob struct {
	ID               string            `json:"id"`                // job UUID
	UserID           string            `json:"user_id,omitempty"` // owner; empty when auth is disabled
	Model            string            `json:"model"`             // source model (HF name or local path)
	Backend          string            `json:"backend"`           // backend used to run the job
	ModelID          string            `json:"model_id,omitempty"` // per-job model-loader ID for the backend process
	QuantizationType string            `json:"quantization_type"`
	Status           string            `json:"status"` // queued, downloading, converting, quantizing, completed, failed, stopped
	Message          string            `json:"message,omitempty"`
	OutputDir        string            `json:"output_dir"`
	OutputFile       string            `json:"output_file,omitempty"` // path to final GGUF
	ExtraOptions     map[string]string `json:"extra_options,omitempty"`
	CreatedAt        string            `json:"created_at"` // RFC3339 UTC timestamp

	// Import state (tracked separately from quantization status)
	ImportStatus    string `json:"import_status,omitempty"` // "", "importing", "completed", "failed"
	ImportMessage   string `json:"import_message,omitempty"`
	ImportModelName string `json:"import_model_name,omitempty"` // registered model name after import

	// Full config for reuse
	Config *QuantizationJobRequest `json:"config,omitempty"`
}

// QuantizationJobResponse is the REST API response when creating a job.
type QuantizationJobResponse struct {
	ID      string `json:"id"`      // new job UUID
	Status  string `json:"status"`  // initial status, "queued"
	Message string `json:"message"` // backend start message
}

// QuantizationProgressEvent is an SSE event for quantization progress.
type QuantizationProgressEvent struct {
	JobID           string             `json:"job_id"`
	ProgressPercent float32            `json:"progress_percent"`
	Status          string             `json:"status"`
	Message         string             `json:"message,omitempty"`
	OutputFile      string             `json:"output_file,omitempty"`
	ExtraMetrics    map[string]float32 `json:"extra_metrics,omitempty"` // backend-specific numeric metrics
}

// QuantizationImportRequest is the REST API request to import a quantized model.
type QuantizationImportRequest struct {
	Name string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
}
|
||||
574
core/services/quantization.go
Normal file
574
core/services/quantization.go
Normal file
@@ -0,0 +1,574 @@
|
||||
package services
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery/importers"
	"github.com/mudler/LocalAI/core/schema"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
	"gopkg.in/yaml.v3"
)
|
||||
|
||||
// QuantizationService manages quantization jobs and their lifecycle.
//
// Jobs live in an in-memory map guarded by mu and are mirrored to disk
// (one state.json per job directory) so they survive restarts.
type QuantizationService struct {
	appConfig    *config.ApplicationConfig // global app config (data path, loader options)
	modelLoader  *model.ModelLoader        // loads/shuts down the quantization backend processes
	configLoader *config.ModelConfigLoader // reloaded after imports so new models register

	mu   sync.Mutex                         // guards jobs and all mutations of job fields
	jobs map[string]*schema.QuantizationJob // job ID -> job state
}
|
||||
|
||||
// NewQuantizationService creates a new QuantizationService.
|
||||
func NewQuantizationService(
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelLoader *model.ModelLoader,
|
||||
configLoader *config.ModelConfigLoader,
|
||||
) *QuantizationService {
|
||||
s := &QuantizationService{
|
||||
appConfig: appConfig,
|
||||
modelLoader: modelLoader,
|
||||
configLoader: configLoader,
|
||||
jobs: make(map[string]*schema.QuantizationJob),
|
||||
}
|
||||
s.loadAllJobs()
|
||||
return s
|
||||
}
|
||||
|
||||
// quantizationBaseDir returns the base directory for quantization job data.
// All job directories (and therefore all job outputs) live under
// <DataPath>/quantization.
func (s *QuantizationService) quantizationBaseDir() string {
	return filepath.Join(s.appConfig.DataPath, "quantization")
}
|
||||
|
||||
// jobDir returns the directory for a specific job:
// <DataPath>/quantization/<jobID>, holding state.json plus any model
// artifacts produced by the backend.
func (s *QuantizationService) jobDir(jobID string) string {
	return filepath.Join(s.quantizationBaseDir(), jobID)
}
|
||||
|
||||
// saveJobState persists a job's state to disk as state.json.
|
||||
func (s *QuantizationService) saveJobState(job *schema.QuantizationJob) {
|
||||
dir := s.jobDir(job.ID)
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
xlog.Error("Failed to create quantization job directory", "job_id", job.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(job, "", " ")
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal quantization job state", "job_id", job.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
statePath := filepath.Join(dir, "state.json")
|
||||
if err := os.WriteFile(statePath, data, 0640); err != nil {
|
||||
xlog.Error("Failed to write quantization job state", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// loadAllJobs scans the quantization directory for persisted jobs and loads them.
//
// Called once from NewQuantizationService, before the service is shared
// with other goroutines, so it does not take s.mu. Each subdirectory is
// expected to contain a state.json written by saveJobState; unreadable
// or unparsable entries are skipped. Jobs (or imports) that were
// mid-flight at the last shutdown are normalized to a stale terminal
// state, since the backend process that was driving them no longer exists.
func (s *QuantizationService) loadAllJobs() {
	baseDir := s.quantizationBaseDir()
	entries, err := os.ReadDir(baseDir)
	if err != nil {
		// Directory doesn't exist yet — that's fine
		return
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		statePath := filepath.Join(baseDir, entry.Name(), "state.json")
		data, err := os.ReadFile(statePath)
		if err != nil {
			// Missing/unreadable state file: not a job directory we track.
			continue
		}

		var job schema.QuantizationJob
		if err := json.Unmarshal(data, &job); err != nil {
			xlog.Warn("Failed to parse quantization job state", "path", statePath, "error", err)
			continue
		}

		// Jobs that were running when we shut down are now stale
		if job.Status == "queued" || job.Status == "downloading" || job.Status == "converting" || job.Status == "quantizing" {
			job.Status = "stopped"
			job.Message = "Server restarted while job was running"
		}

		// Imports that were in progress are now stale
		if job.ImportStatus == "importing" {
			job.ImportStatus = "failed"
			job.ImportMessage = "Server restarted while import was running"
		}

		// Keyed by the ID stored in the file, not the directory name.
		s.jobs[job.ID] = &job
	}

	if len(s.jobs) > 0 {
		xlog.Info("Loaded persisted quantization jobs", "count", len(s.jobs))
	}
}
|
||||
|
||||
// StartJob starts a new quantization job.
|
||||
func (s *QuantizationService) StartJob(ctx context.Context, userID string, req schema.QuantizationJobRequest) (*schema.QuantizationJobResponse, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
jobID := uuid.New().String()
|
||||
|
||||
backendName := req.Backend
|
||||
if backendName == "" {
|
||||
backendName = "llama-cpp-quantization"
|
||||
}
|
||||
|
||||
quantType := req.QuantizationType
|
||||
if quantType == "" {
|
||||
quantType = "q4_k_m"
|
||||
}
|
||||
|
||||
// Always use DataPath for output — not user-configurable
|
||||
outputDir := filepath.Join(s.quantizationBaseDir(), jobID)
|
||||
|
||||
// Build gRPC request
|
||||
grpcReq := &pb.QuantizationRequest{
|
||||
Model: req.Model,
|
||||
QuantizationType: quantType,
|
||||
OutputDir: outputDir,
|
||||
JobId: jobID,
|
||||
ExtraOptions: req.ExtraOptions,
|
||||
}
|
||||
|
||||
// Load the quantization backend (per-job model ID so multiple jobs can run concurrently)
|
||||
modelID := backendName + "-quantize-" + jobID
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(backendName),
|
||||
model.WithModel(backendName),
|
||||
model.WithModelID(modelID),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
|
||||
}
|
||||
|
||||
// Start quantization via gRPC
|
||||
result, err := backendModel.StartQuantization(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start quantization: %w", err)
|
||||
}
|
||||
if !result.Success {
|
||||
return nil, fmt.Errorf("quantization failed to start: %s", result.Message)
|
||||
}
|
||||
|
||||
// Track the job
|
||||
job := &schema.QuantizationJob{
|
||||
ID: jobID,
|
||||
UserID: userID,
|
||||
Model: req.Model,
|
||||
Backend: backendName,
|
||||
ModelID: modelID,
|
||||
QuantizationType: quantType,
|
||||
Status: "queued",
|
||||
OutputDir: outputDir,
|
||||
ExtraOptions: req.ExtraOptions,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Config: &req,
|
||||
}
|
||||
s.jobs[jobID] = job
|
||||
s.saveJobState(job)
|
||||
|
||||
return &schema.QuantizationJobResponse{
|
||||
ID: jobID,
|
||||
Status: "queued",
|
||||
Message: result.Message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetJob returns a quantization job by ID.
|
||||
func (s *QuantizationService) GetJob(userID, jobID string) (*schema.QuantizationJob, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ListJobs returns all jobs for a user, sorted by creation time (newest first).
|
||||
func (s *QuantizationService) ListJobs(userID string) []*schema.QuantizationJob {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var result []*schema.QuantizationJob
|
||||
for _, job := range s.jobs {
|
||||
if userID == "" || job.UserID == userID {
|
||||
result = append(result, job)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].CreatedAt > result[j].CreatedAt
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// StopJob stops a running quantization job.
//
// Ownership is checked like GetJob (empty userID bypasses the check).
// The backend process driving the job is shut down via the model loader,
// then the job is marked "stopped" and persisted.
//
// NOTE(review): s.mu is released between the lookup and the status
// update, and job.ModelID/job.Backend are read without the lock; those
// fields are only written at creation, so this looks benign — confirm
// if that invariant ever changes.
func (s *QuantizationService) StopJob(ctx context.Context, userID, jobID string) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	// Kill the backend process directly
	stopModelID := job.ModelID
	if stopModelID == "" {
		// Fallback ID for jobs persisted before per-job model IDs existed.
		stopModelID = job.Backend + "-quantize"
	}
	s.modelLoader.ShutdownModel(stopModelID)

	s.mu.Lock()
	job.Status = "stopped"
	job.Message = "Quantization stopped by user"
	s.saveJobState(job)
	s.mu.Unlock()

	return nil
}
|
||||
|
||||
// DeleteJob removes a quantization job and its associated data from disk.
//
// Actively running jobs and jobs with an import in flight are rejected —
// the caller must stop them first. On success the job is dropped from
// the in-memory map, its directory (state.json + outputs) is deleted,
// and, if the output had been imported into LocalAI, the imported model
// files and YAML config are removed and the config loader is reloaded.
// Filesystem failures during cleanup are logged but do not fail the call.
func (s *QuantizationService) DeleteJob(userID, jobID string) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	// Same "not found" error for ownership mismatch: don't leak job existence.
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}

	// Reject deletion of actively running jobs
	activeStatuses := map[string]bool{
		"queued": true, "downloading": true, "converting": true, "quantizing": true,
	}
	if activeStatuses[job.Status] {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status)
	}
	if job.ImportStatus == "importing" {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: import in progress", jobID)
	}

	// Drop from the map under the lock; do the slow filesystem work after.
	importModelName := job.ImportModelName
	delete(s.jobs, jobID)
	s.mu.Unlock()

	// Remove job directory (state.json, output files)
	jobDir := s.jobDir(jobID)
	if err := os.RemoveAll(jobDir); err != nil {
		xlog.Warn("Failed to remove quantization job directory", "job_id", jobID, "path", jobDir, "error", err)
	}

	// If an imported model exists, clean it up too
	if importModelName != "" {
		modelsPath := s.appConfig.SystemState.Model.ModelsPath
		modelDir := filepath.Join(modelsPath, importModelName)
		configPath := filepath.Join(modelsPath, importModelName+".yaml")

		if err := os.RemoveAll(modelDir); err != nil {
			xlog.Warn("Failed to remove imported model directory", "path", modelDir, "error", err)
		}
		if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
			xlog.Warn("Failed to remove imported model config", "path", configPath, "error", err)
		}

		// Reload model configs
		if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
			xlog.Warn("Failed to reload configs after delete", "error", err)
		}
	}

	xlog.Info("Deleted quantization job", "job_id", jobID)
	return nil
}
|
||||
|
||||
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
//
// The backend is re-loaded under the job's per-job model ID (a no-op if
// already running). Each progress update is folded into the tracked job
// state (and persisted) before being forwarded to the caller. The call
// blocks until the backend closes the stream or ctx is cancelled.
func (s *QuantizationService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.QuantizationProgressEvent)) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	// Same "not found" error on ownership mismatch: don't leak existence.
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	// NOTE(review): job.ModelID/job.Backend are read without the lock
	// here; they are only written at job creation, so this looks benign —
	// confirm if that invariant ever changes.
	streamModelID := job.ModelID
	if streamModelID == "" {
		// Fallback ID for jobs persisted before per-job model IDs existed.
		streamModelID = job.Backend + "-quantize"
	}
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(streamModelID),
	)
	if err != nil {
		return fmt.Errorf("failed to load backend: %w", err)
	}

	return backendModel.QuantizationProgress(ctx, &pb.QuantizationProgressRequest{
		JobId: jobID,
	}, func(update *pb.QuantizationProgressUpdate) {
		// Update job status and persist
		s.mu.Lock()
		if j, ok := s.jobs[jobID]; ok {
			// Don't let progress updates overwrite terminal states
			isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
			if !isTerminal {
				j.Status = update.Status
			}
			if update.Message != "" {
				j.Message = update.Message
			}
			if update.OutputFile != "" {
				j.OutputFile = update.OutputFile
			}
			s.saveJobState(j)
		}
		s.mu.Unlock()

		// Convert extra metrics
		extraMetrics := make(map[string]float32)
		for k, v := range update.ExtraMetrics {
			extraMetrics[k] = v
		}

		event := &schema.QuantizationProgressEvent{
			JobID:           update.JobId,
			ProgressPercent: update.ProgressPercent,
			Status:          update.Status,
			Message:         update.Message,
			OutputFile:      update.OutputFile,
			ExtraMetrics:    extraMetrics,
		}
		callback(event)
	})
}
|
||||
|
||||
// Compiled once at package scope: the previous implementation compiled
// both patterns on every call (regexp.MustCompile inside the function).
var (
	quantNameInvalidRe  = regexp.MustCompile(`[^a-zA-Z0-9\-]`)
	quantNameCollapseRe = regexp.MustCompile(`-+`)
)

// sanitizeQuantModelName replaces non-alphanumeric characters with hyphens and lowercases.
// Every character outside [a-zA-Z0-9-] becomes "-", runs of hyphens are
// collapsed to one, leading/trailing hyphens are trimmed, and the result
// is lowercased — turning e.g. an HF repo path into a safe model name.
// May return "" when the input contains no alphanumeric characters.
func sanitizeQuantModelName(s string) string {
	s = quantNameInvalidRe.ReplaceAllString(s, "-")
	s = quantNameCollapseRe.ReplaceAllString(s, "-")
	return strings.ToLower(strings.Trim(s, "-"))
}
|
||||
|
||||
// ImportModel imports a quantized model into LocalAI asynchronously.
|
||||
func (s *QuantizationService) ImportModel(ctx context.Context, userID, jobID string, req schema.QuantizationImportRequest) (string, error) {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if job.Status != "completed" {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("job %s is not completed (status: %s)", jobID, job.Status)
|
||||
}
|
||||
if job.ImportStatus == "importing" {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("import already in progress for job %s", jobID)
|
||||
}
|
||||
if job.OutputFile == "" {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("no output file for job %s", jobID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Compute model name
|
||||
modelName := req.Name
|
||||
if modelName == "" {
|
||||
base := sanitizeQuantModelName(job.Model)
|
||||
if base == "" {
|
||||
base = "model"
|
||||
}
|
||||
shortID := jobID
|
||||
if len(shortID) > 8 {
|
||||
shortID = shortID[:8]
|
||||
}
|
||||
modelName = base + "-" + job.QuantizationType + "-" + shortID
|
||||
}
|
||||
|
||||
// Compute output path in models directory
|
||||
modelsPath := s.appConfig.SystemState.Model.ModelsPath
|
||||
outputPath := filepath.Join(modelsPath, modelName)
|
||||
|
||||
// Check for name collision
|
||||
configPath := filepath.Join(modelsPath, modelName+".yaml")
|
||||
if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
|
||||
return "", fmt.Errorf("invalid model name: %w", err)
|
||||
}
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
|
||||
}
|
||||
|
||||
// Create output directory
|
||||
if err := os.MkdirAll(outputPath, 0750); err != nil {
|
||||
return "", fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// Set import status
|
||||
s.mu.Lock()
|
||||
job.ImportStatus = "importing"
|
||||
job.ImportMessage = ""
|
||||
job.ImportModelName = ""
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Launch the import in a background goroutine
|
||||
go func() {
|
||||
s.setImportMessage(job, "Copying quantized model...")
|
||||
|
||||
// Copy the output file to the models directory
|
||||
srcFile := job.OutputFile
|
||||
dstFile := filepath.Join(outputPath, filepath.Base(srcFile))
|
||||
|
||||
srcData, err := os.ReadFile(srcFile)
|
||||
if err != nil {
|
||||
s.setImportFailed(job, fmt.Sprintf("failed to read output file: %v", err))
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(dstFile, srcData, 0644); err != nil {
|
||||
s.setImportFailed(job, fmt.Sprintf("failed to write model file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.setImportMessage(job, "Generating model configuration...")
|
||||
|
||||
// Auto-import: detect format and generate config
|
||||
cfg, err := importers.ImportLocalPath(outputPath, modelName)
|
||||
if err != nil {
|
||||
s.setImportFailed(job, fmt.Sprintf("model copied to %s but config generation failed: %v", outputPath, err))
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Name = modelName
|
||||
|
||||
// Write YAML config
|
||||
yamlData, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
s.setImportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
s.setImportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.setImportMessage(job, "Registering model with LocalAI...")
|
||||
|
||||
// Reload configs so the model is immediately available
|
||||
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
xlog.Warn("Failed to reload configs after import", "error", err)
|
||||
}
|
||||
if err := s.configLoader.Preload(modelsPath); err != nil {
|
||||
xlog.Warn("Failed to preload after import", "error", err)
|
||||
}
|
||||
|
||||
xlog.Info("Quantized model imported and registered", "job_id", jobID, "model_name", modelName)
|
||||
|
||||
s.mu.Lock()
|
||||
job.ImportStatus = "completed"
|
||||
job.ImportModelName = modelName
|
||||
job.ImportMessage = ""
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
// setImportMessage updates the import message and persists the job state.
|
||||
func (s *QuantizationService) setImportMessage(job *schema.QuantizationJob, msg string) {
|
||||
s.mu.Lock()
|
||||
job.ImportMessage = msg
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// setImportFailed sets the import status to failed with a message.
|
||||
func (s *QuantizationService) setImportFailed(job *schema.QuantizationJob, message string) {
|
||||
xlog.Error("Quantization import failed", "job_id", job.ID, "error", message)
|
||||
s.mu.Lock()
|
||||
job.ImportStatus = "failed"
|
||||
job.ImportMessage = message
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetOutputPath returns the path to the quantized model file and a download name.
|
||||
func (s *QuantizationService) GetOutputPath(userID, jobID string) (string, string, error) {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if job.Status != "completed" {
|
||||
s.mu.Unlock()
|
||||
return "", "", fmt.Errorf("job not completed (status: %s)", job.Status)
|
||||
}
|
||||
if job.OutputFile == "" {
|
||||
s.mu.Unlock()
|
||||
return "", "", fmt.Errorf("no output file for job %s", jobID)
|
||||
}
|
||||
outputFile := job.OutputFile
|
||||
s.mu.Unlock()
|
||||
|
||||
if _, err := os.Stat(outputFile); os.IsNotExist(err) {
|
||||
return "", "", fmt.Errorf("output file not found: %s", outputFile)
|
||||
}
|
||||
|
||||
downloadName := filepath.Base(outputFile)
|
||||
return outputFile, downloadName, nil
|
||||
}
|
||||
@@ -13,15 +13,13 @@ LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuni
|
||||
|---------|--------|-------------|-----------------|---------------|
|
||||
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |
|
||||
|
||||
## Enabling Fine-Tuning
|
||||
## Availability
|
||||
|
||||
Fine-tuning is disabled by default. Enable it with:
|
||||
Fine-tuning is always enabled. When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
|
||||
|
||||
```bash
|
||||
LOCALAI_ENABLE_FINETUNING=true local-ai
|
||||
```
|
||||
|
||||
When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
|
||||
{{% alert note %}}
|
||||
This feature is **experimental** and may change in future releases.
|
||||
{{% /alert %}}
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
||||
158
docs/content/features/quantization.md
Normal file
158
docs/content/features/quantization.md
Normal file
@@ -0,0 +1,158 @@
|
||||
+++
|
||||
disableToc = false
|
||||
title = "Model Quantization"
|
||||
weight = 19
|
||||
url = '/features/quantization/'
|
||||
+++
|
||||
|
||||
LocalAI supports model quantization directly through the API and Web UI. Quantization converts HuggingFace models to GGUF format and compresses them to smaller sizes for efficient inference with llama.cpp.
|
||||
|
||||
{{% alert note %}}
|
||||
This feature is **experimental** and may change in future releases.
|
||||
{{% /alert %}}
|
||||
|
||||
## Supported Backends
|
||||
|
||||
| Backend | Description | Quantization Types | Platforms |
|
||||
|---------|-------------|-------------------|-----------|
|
||||
| **llama-cpp-quantization** | Downloads HF models, converts to GGUF, and quantizes using llama.cpp | q2_k, q3_k_s, q3_k_m, q3_k_l, q4_0, q4_k_s, q4_k_m, q5_0, q5_k_s, q5_k_m, q6_k, q8_0, f16 | CPU (Linux, macOS) |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start a quantization job
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/quantization/jobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "unsloth/functiongemma-270m-it",
|
||||
"quantization_type": "q4_k_m"
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. Monitor progress (SSE stream)
|
||||
|
||||
```bash
|
||||
curl -N http://localhost:8080/api/quantization/jobs/{job_id}/progress
|
||||
```
|
||||
|
||||
### 3. Download the quantized model
|
||||
|
||||
```bash
|
||||
curl -o model.gguf http://localhost:8080/api/quantization/jobs/{job_id}/download
|
||||
```
|
||||
|
||||
### 4. Or import it directly into LocalAI
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/quantization/jobs/{job_id}/import \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "my-quantized-model"
|
||||
}'
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `POST` | `/api/quantization/jobs` | Start a quantization job |
|
||||
| `GET` | `/api/quantization/jobs` | List all jobs |
|
||||
| `GET` | `/api/quantization/jobs/:id` | Get job details |
|
||||
| `POST` | `/api/quantization/jobs/:id/stop` | Stop a running job |
|
||||
| `DELETE` | `/api/quantization/jobs/:id` | Delete a job and its data |
|
||||
| `GET` | `/api/quantization/jobs/:id/progress` | SSE progress stream |
|
||||
| `POST` | `/api/quantization/jobs/:id/import` | Import quantized model into LocalAI |
|
||||
| `GET` | `/api/quantization/jobs/:id/download` | Download quantized GGUF file |
|
||||
| `GET` | `/api/quantization/backends` | List available quantization backends |
|
||||
|
||||
### Job Request Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `model` | string | HuggingFace model ID or local path (required) |
|
||||
| `backend` | string | Backend name (default: `llama-cpp-quantization`) |
|
||||
| `quantization_type` | string | Quantization format (default: `q4_k_m`) |
|
||||
| `extra_options` | map | Backend-specific options (see below) |
|
||||
|
||||
### Extra Options
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `hf_token` | HuggingFace token for gated models |
|
||||
|
||||
### Import Request Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `name` | string | Model name for LocalAI (auto-generated if empty) |
|
||||
|
||||
### Job Status Values
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `queued` | Job created, waiting to start |
|
||||
| `downloading` | Downloading model from HuggingFace |
|
||||
| `converting` | Converting model to f16 GGUF |
|
||||
| `quantizing` | Running quantization |
|
||||
| `completed` | Quantization finished successfully |
|
||||
| `failed` | Job failed (check message for details) |
|
||||
| `stopped` | Job stopped by user |
|
||||
|
||||
### Progress Stream
|
||||
|
||||
The `GET /api/quantization/jobs/:id/progress` endpoint returns Server-Sent Events (SSE) with JSON payloads:
|
||||
|
||||
```json
|
||||
{
|
||||
"job_id": "abc-123",
|
||||
"progress_percent": 65.0,
|
||||
"status": "quantizing",
|
||||
"message": "[ 234/ 567] quantizing blk.5.attn_k.weight ...",
|
||||
"output_file": "",
|
||||
"extra_metrics": {}
|
||||
}
|
||||
```
|
||||
|
||||
When the job completes, `output_file` contains the path to the quantized GGUF file and `extra_metrics` includes `file_size_mb`.
|
||||
|
||||
## Quantization Types
|
||||
|
||||
| Type | Size | Quality | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `q2_k` | Smallest | Lowest | 2-bit quantization |
|
||||
| `q3_k_s` | Very small | Low | 3-bit small |
|
||||
| `q3_k_m` | Very small | Low | 3-bit medium |
|
||||
| `q3_k_l` | Small | Low-medium | 3-bit large |
|
||||
| `q4_0` | Small | Medium | 4-bit legacy |
|
||||
| `q4_k_s` | Small | Medium | 4-bit small |
|
||||
| `q4_k_m` | Small | **Good** | **4-bit medium (recommended)** |
|
||||
| `q5_0` | Medium | Good | 5-bit legacy |
|
||||
| `q5_k_s` | Medium | Good | 5-bit small |
|
||||
| `q5_k_m` | Medium | Very good | 5-bit medium |
|
||||
| `q6_k` | Large | Excellent | 6-bit |
|
||||
| `q8_0` | Large | Near-lossless | 8-bit |
|
||||
| `f16` | Largest | Lossless | 16-bit (no quantization, GGUF conversion only) |
|
||||
|
||||
The UI also supports entering a custom quantization type string for any format supported by `llama-quantize`.
|
||||
|
||||
## Web UI
|
||||
|
||||
A "Quantize" page appears in the sidebar under the Tools section. The UI provides:
|
||||
|
||||
1. **Job Configuration** — Select model, quantization type (dropdown with presets or custom input), backend, and HuggingFace token
|
||||
2. **Progress Monitor** — Real-time progress bar and log output via SSE
|
||||
3. **Jobs List** — View all quantization jobs with status, stop/delete actions
|
||||
4. **Output** — Download the quantized GGUF file or import it directly into LocalAI for immediate use
|
||||
|
||||
## Architecture
|
||||
|
||||
Quantization uses the same gRPC backend architecture as fine-tuning:
|
||||
|
||||
1. **Proto layer**: `QuantizationRequest`, `QuantizationProgress` (streaming), `StopQuantization`
|
||||
2. **Python backend**: Downloads model, runs `convert_hf_to_gguf.py` and `llama-quantize`
|
||||
3. **Go service**: Manages job lifecycle, state persistence, async import
|
||||
4. **REST API**: HTTP endpoints with SSE progress streaming
|
||||
5. **React UI**: Configuration form, real-time progress monitor, download/import panel
|
||||
@@ -70,4 +70,9 @@ type Backend interface {
|
||||
StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error)
|
||||
ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
|
||||
// Quantization
|
||||
StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error)
|
||||
QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error
|
||||
StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||
}
|
||||
|
||||
@@ -140,6 +140,18 @@ func (llm *Base) ExportModel(*pb.ExportModelRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// StartQuantization is the default stub; backends that support model
// quantization override it.
func (llm *Base) StartQuantization(*pb.QuantizationRequest) (*pb.QuantizationJobResult, error) {
	return nil, fmt.Errorf("unimplemented")
}
|
||||
|
||||
// QuantizationProgress is the default stub; backends that support model
// quantization override it to push updates into the supplied channel.
func (llm *Base) QuantizationProgress(*pb.QuantizationProgressRequest, chan *pb.QuantizationProgressUpdate) error {
	return fmt.Errorf("unimplemented")
}
|
||||
|
||||
// StopQuantization is the default stub; backends that support model
// quantization override it.
func (llm *Base) StopQuantization(*pb.QuantizationStopRequest) error {
	return fmt.Errorf("unimplemented")
}
|
||||
|
||||
func memoryUsage() *pb.MemoryUsageData {
|
||||
mud := pb.MemoryUsageData{
|
||||
Breakdown: make(map[string]uint64),
|
||||
|
||||
@@ -768,6 +768,98 @@ func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opt
|
||||
return client.ExportModel(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
c.wdMark()
|
||||
defer c.wdUnMark()
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StartQuantization(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// QuantizationProgress opens a server stream of quantization progress updates
// from the backend and invokes f for every update received. It returns nil
// when the stream ends normally (io.EOF), ctx.Err() on cancellation, or the
// transport error otherwise.
func (c *Client) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	stream, err := client.QuantizationProgress(ctx, in, opts...)
	if err != nil {
		return err
	}

	for {
		// Bail out promptly if the caller cancelled between messages.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		update, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Prefer reporting the cancellation over the transport error it
			// typically causes.
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
		f(update)
	}

	return nil
}
|
||||
|
||||
func (c *Client) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
}
|
||||
c.setBusy(true)
|
||||
defer c.setBusy(false)
|
||||
c.wdMark()
|
||||
defer c.wdUnMark()
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewBackendClient(conn)
|
||||
return client.StopQuantization(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
|
||||
@@ -147,6 +147,22 @@ func (e *embedBackend) ExportModel(ctx context.Context, in *pb.ExportModelReques
|
||||
return e.s.ExportModel(ctx, in)
|
||||
}
|
||||
|
||||
// StartQuantization forwards the request straight to the in-process backend
// server; gRPC call options are accepted for interface parity but unused.
func (e *embedBackend) StartQuantization(ctx context.Context, in *pb.QuantizationRequest, opts ...grpc.CallOption) (*pb.QuantizationJobResult, error) {
	return e.s.StartQuantization(ctx, in)
}
|
||||
|
||||
// QuantizationProgress adapts the in-process backend's server-stream API to
// the client callback style by wrapping f in a fake server stream.
func (e *embedBackend) QuantizationProgress(ctx context.Context, in *pb.QuantizationProgressRequest, f func(update *pb.QuantizationProgressUpdate), opts ...grpc.CallOption) error {
	bs := &embedBackendQuantizationProgressStream{
		ctx: ctx,
		fn:  f,
	}
	return e.s.QuantizationProgress(in, bs)
}
|
||||
|
||||
// StopQuantization forwards the request straight to the in-process backend
// server; gRPC call options are accepted for interface parity but unused.
func (e *embedBackend) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	return e.s.StopQuantization(ctx, in)
}
|
||||
|
||||
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
|
||||
|
||||
type embedBackendFineTuneProgressStream struct {
|
||||
@@ -185,6 +201,44 @@ func (e *embedBackendFineTuneProgressStream) RecvMsg(m any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile-time check that the embedded stream satisfies the generated
// server-stream interface for quantization progress.
var _ pb.Backend_QuantizationProgressServer = new(embedBackendQuantizationProgressStream)

// embedBackendQuantizationProgressStream adapts a progress callback to the
// pb.Backend_QuantizationProgressServer streaming interface so an embedded
// (in-process) backend can report progress without a real gRPC stream.
type embedBackendQuantizationProgressStream struct {
	ctx context.Context                            // caller's context, surfaced via Context()
	fn  func(update *pb.QuantizationProgressUpdate) // invoked for every Send
}
|
||||
|
||||
// Send delivers a progress update to the registered callback; it never fails.
func (e *embedBackendQuantizationProgressStream) Send(update *pb.QuantizationProgressUpdate) error {
	e.fn(update)
	return nil
}
|
||||
|
||||
// SetHeader is a no-op: the embedded stream carries no gRPC metadata.
func (e *embedBackendQuantizationProgressStream) SetHeader(md metadata.MD) error {
	return nil
}

// SendHeader is a no-op: the embedded stream carries no gRPC metadata.
func (e *embedBackendQuantizationProgressStream) SendHeader(md metadata.MD) error {
	return nil
}

// SetTrailer is a no-op: the embedded stream carries no gRPC metadata.
func (e *embedBackendQuantizationProgressStream) SetTrailer(md metadata.MD) {
}

// Context returns the caller-supplied context so the backend can observe
// cancellation.
func (e *embedBackendQuantizationProgressStream) Context() context.Context {
	return e.ctx
}
|
||||
|
||||
// SendMsg forwards progress updates to Send; messages of any other type are
// silently dropped.
func (e *embedBackendQuantizationProgressStream) SendMsg(m any) error {
	if x, ok := m.(*pb.QuantizationProgressUpdate); ok {
		return e.Send(x)
	}
	return nil
}
|
||||
|
||||
// RecvMsg is a no-op: the progress stream is send-only.
func (e *embedBackendQuantizationProgressStream) RecvMsg(m any) error {
	return nil
}
|
||||
|
||||
type embedBackendServerStream struct {
|
||||
ctx context.Context
|
||||
fn func(reply *pb.Reply)
|
||||
|
||||
@@ -42,6 +42,11 @@ type AIModel interface {
|
||||
StopFineTune(*pb.FineTuneStopRequest) error
|
||||
ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error)
|
||||
ExportModel(*pb.ExportModelRequest) error
|
||||
|
||||
// Quantization
|
||||
StartQuantization(*pb.QuantizationRequest) (*pb.QuantizationJobResult, error)
|
||||
QuantizationProgress(*pb.QuantizationProgressRequest, chan *pb.QuantizationProgressUpdate) error
|
||||
StopQuantization(*pb.QuantizationStopRequest) error
|
||||
}
|
||||
|
||||
func newReply(s string) *pb.Reply {
|
||||
|
||||
@@ -377,6 +377,51 @@ func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*p
|
||||
return &pb.Result{Message: "Model exported", Success: true}, nil
|
||||
}
|
||||
|
||||
// StartQuantization forwards a quantization job request to the loaded
// backend, honoring the backend's locking requirements.
func (s *server) StartQuantization(ctx context.Context, in *pb.QuantizationRequest) (*pb.QuantizationJobResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.StartQuantization(in)
	if err != nil {
		// Return a populated result alongside the error so clients that only
		// inspect the message still see the failure reason.
		return &pb.QuantizationJobResult{Success: false, Message: fmt.Sprintf("Error starting quantization: %s", err.Error())}, err
	}
	return res, nil
}
|
||||
|
||||
// QuantizationProgress bridges the backend's channel-based progress API to
// the gRPC server stream: updates pushed into updateChan by the backend are
// forwarded to the client until the backend closes the channel.
func (s *server) QuantizationProgress(in *pb.QuantizationProgressRequest, stream pb.Backend_QuantizationProgressServer) error {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	updateChan := make(chan *pb.QuantizationProgressUpdate)

	// The forwarder keeps draining updateChan so the backend producer never
	// blocks; done signals that the channel was closed and drained.
	done := make(chan bool)
	go func() {
		for update := range updateChan {
			// NOTE(review): the Send error is ignored — a broken stream keeps
			// consuming updates silently. Confirm this matches the fine-tune
			// progress handler's intended behavior.
			stream.Send(update)
		}
		done <- true
	}()

	err := s.llm.QuantizationProgress(in, updateChan)
	<-done

	return err
}
|
||||
|
||||
func (s *server) StopQuantization(ctx context.Context, in *pb.QuantizationStopRequest) (*pb.Result, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
err := s.llm.StopQuantization(in)
|
||||
if err != nil {
|
||||
return &pb.Result{Message: fmt.Sprintf("Error stopping quantization: %s", err.Error()), Success: false}, err
|
||||
}
|
||||
return &pb.Result{Message: "Quantization stopped", Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
|
||||
if s.llm.Locking() {
|
||||
s.llm.Lock()
|
||||
|
||||
Reference in New Issue
Block a user