feat: add (experimental) fine-tuning support with TRL (#9088)

* feat: add fine-tuning endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(experimental): add fine-tuning endpoint and TRL support

This changeset defines new gRPC signatures for fine-tuning backends and
adds the TRL backend as the initial fine-tuning engine. The implementation
also supports exporting to GGUF and automatically importing the result into
LocalAI after fine-tuning.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* commit TRL backend, stop by killing process

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* move fine-tune to generic features

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add evals, reorder menu

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-21 02:08:02 +01:00
committed by GitHub
parent f7e3aab4fc
commit d9c1db2b87
49 changed files with 5652 additions and 110 deletions


@@ -0,0 +1,141 @@
# Debugging and Rebuilding Backends
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
## Architecture Overview
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
## Diagnosing Failures
### 1. Check the logs
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
```
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
```
Common error patterns:
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
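When a log is noisy, a small helper can pull out just the backend id and message from each GRPC stderr entry — an illustrative sketch assuming the `id="..." line="..."` format shown above:

```python
import re

# Matches the id="..." line="..." fields in LocalAI's GRPC stderr log
# entries (format as quoted above; this helper is illustrative).
LOG_RE = re.compile(r'id="(?P<id>[^"]+)"\s+line="(?P<line>[^"]*)"')

def parse_grpc_log(entry):
    """Return (backend_id, message) for a GRPC stderr log line, or None."""
    m = LOG_RE.search(entry)
    return (m.group("id"), m.group("line")) if m else None

entry = 'GRPC stderr id="trl-finetune-127.0.0.1:37335" line="ModuleNotFoundError: No module named datasets"'
backend_id, message = parse_grpc_log(entry)
# backend_id → "trl-finetune-127.0.0.1:37335"
```

Filtering on the backend id this way is handy when several backends are running and their stderr lines interleave.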
### 2. Test the Python environment directly
You can run the installed venv's Python to check imports without starting the full server:
```bash
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
```
If `pip` is missing from the venv, bootstrap it:
```bash
backends/<name>/venv/bin/python -m ensurepip
```
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
### 3. Check upstream dependency constraints
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
```
https://github.com/huggingface/trl/blob/main/requirements.txt
```
Pin minimum versions in the backend's requirements files to match upstream.
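The resulting pins in the backend's requirements files might look like this (package names and versions are illustrative — check the upstream file for the actual constraints):

```
datasets>=2.14.0
pyarrow>=14.0.0
```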
## Common Fixes
### Missing gRPC methods
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
```python
def LoadModel(self, request, context):
"""No-op — actual loading happens elsewhere."""
return backend_pb2.Result(success=True, message="OK")
```
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
### Dependency version conflicts
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
1. Identify the broken import in the logs
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
3. Check upstream requirements for version constraints
4. Update **all** requirements files in `backend/python/<name>/`:
- `requirements.txt` — base deps (grpcio, protobuf)
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
- `requirements-cublas12.txt` — CUDA 12
- `requirements-cublas13.txt` — CUDA 13
5. Rebuild: `make backends/<name>`
### PyTorch index conflicts (uv resolver)
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
```bash
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
```
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
## Rebuilding
### Rebuild a single backend
```bash
make backends/<name>
```
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
```bash
GO_TAGS=auth make build
```
### Rebuild and restart
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
```bash
# Kill existing process
kill <pid>
# Restart
./local-ai run --debug [your flags]
```
### Quick iteration (skip Docker rebuild)
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
```bash
# Edit the installed copy
vim backends/<name>/backend.py
# Restart LocalAI to respawn the gRPC process
```
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
## Verification
After fixing and rebuilding:
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
2. Trigger the operation that failed (e.g. start a fine-tuning job)
3. Watch the GRPC stderr/stdout lines for the backend's model ID
4. Confirm no errors in the traceback
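The verification steps above can be scripted against a captured log — an illustrative helper using the message formats quoted in this guide:

```python
def verify_backend_logs(lines, backend="trl"):
    """Scan captured LocalAI log lines; return (registered, error_lines)."""
    registered = any(
        f'Registering backend name="{backend}"' in line for line in lines
    )
    errors = [
        line for line in lines
        if "Traceback" in line or "failed to load backend" in line
    ]
    return registered, errors

log = [
    'Registering backend name="trl"',
    'GRPC stderr id="trl-finetune-127.0.0.1:37335" line="step 10/100 loss=1.2"',
]
registered, errors = verify_backend_logs(log)
# registered → True, errors → []
```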


@@ -118,6 +118,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'true'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
@@ -366,6 +379,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
@@ -757,6 +783,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-trl'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "trl"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'l4t'
cuda-major-version: "13"
cuda-minor-version: "0"


@@ -12,6 +12,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
## Quick Reference


@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
GOCMD=go
GOTEST=$(GOCMD) test
@@ -421,6 +421,7 @@ prepare-test-extra: protogen-python
$(MAKE) -C backend/python/voxcpm
$(MAKE) -C backend/python/whisperx
$(MAKE) -C backend/python/ace-step
$(MAKE) -C backend/python/trl
test-extra: prepare-test-extra
$(MAKE) -C backend/python/transformers test
@@ -440,6 +441,7 @@ test-extra: prepare-test-extra
$(MAKE) -C backend/python/voxcpm test
$(MAKE) -C backend/python/whisperx test
$(MAKE) -C backend/python/ace-step test
$(MAKE) -C backend/python/trl test
DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core
@@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true
# Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
# Pattern rule for docker-save targets
docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
########################################################
### Mock Backend for E2E Tests


@@ -39,6 +39,13 @@ service Backend {
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
// Fine-tuning RPCs
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
rpc ExportModel(ExportModelRequest) returns (Result) {}
}
// Define the empty request
@@ -528,3 +535,105 @@ message ModelMetadataResponse {
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
}
// Fine-tuning messages
message FineTuneRequest {
// Model identification
string model = 1; // HF model name or local path
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
int32 adapter_rank = 10; // LoRA rank (r), default 16
int32 adapter_alpha = 11; // scaling factor, default 16
float adapter_dropout = 12; // default 0.0
repeated string target_modules = 13; // layer names to adapt
// Universal training hyperparameters
float learning_rate = 20; // default 2e-4
int32 num_epochs = 21; // default 3
int32 batch_size = 22; // default 2
int32 gradient_accumulation_steps = 23; // default 4
int32 warmup_steps = 24; // default 5
int32 max_steps = 25; // 0 = use epochs
int32 save_steps = 26; // 0 = only save final
float weight_decay = 27; // default 0.01
bool gradient_checkpointing = 28;
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
int32 seed = 30; // default 3407
string mixed_precision = 31; // fp16, bf16, fp8, no
// Dataset
string dataset_source = 40; // HF dataset ID, local file/dir path
string dataset_split = 41; // train, test, etc.
// Output
string output_dir = 50;
string job_id = 51; // client-assigned or auto-generated
// Resume training from a checkpoint
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
// Backend-specific AND method-specific extensibility
map<string, string> extra_options = 60;
}
message FineTuneJobResult {
string job_id = 1;
bool success = 2;
string message = 3;
}
message FineTuneProgressRequest {
string job_id = 1;
}
message FineTuneProgressUpdate {
string job_id = 1;
int32 current_step = 2;
int32 total_steps = 3;
float current_epoch = 4;
float total_epochs = 5;
float loss = 6;
float learning_rate = 7;
float grad_norm = 8;
float eval_loss = 9;
float eta_seconds = 10;
float progress_percent = 11;
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
string message = 13;
string checkpoint_path = 14; // set when a checkpoint is saved
string sample_path = 15; // set when a sample is generated (video/image backends)
map<string, float> extra_metrics = 16; // method-specific metrics
}
message FineTuneStopRequest {
string job_id = 1;
bool save_checkpoint = 2;
}
message ListCheckpointsRequest {
string output_dir = 1;
}
message ListCheckpointsResponse {
repeated CheckpointInfo checkpoints = 1;
}
message CheckpointInfo {
string path = 1;
int32 step = 2;
float epoch = 3;
float loss = 4;
string created_at = 5;
}
message ExportModelRequest {
string checkpoint_path = 1;
string output_path = 2;
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
string model = 5; // base model name (for merge operations)
map<string, string> extra_options = 6;
}


@@ -3030,3 +3030,54 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-voxtral
- &trl
name: "trl"
alias: "trl"
license: apache-2.0
description: |
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
Works on CPU and GPU.
urls:
- https://github.com/huggingface/trl
tags:
- fine-tuning
- LLM
- CPU
- GPU
- CUDA
capabilities:
default: "cpu-trl"
nvidia: "cuda12-trl"
nvidia-cuda-12: "cuda12-trl"
nvidia-cuda-13: "cuda13-trl"
## TRL backend images
- !!merge <<: *trl
name: "cpu-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
mirrors:
- localai/localai-backends:latest-cpu-trl
- !!merge <<: *trl
name: "cpu-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
mirrors:
- localai/localai-backends:master-cpu-trl
- !!merge <<: *trl
name: "cuda12-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda12-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda12-trl
- !!merge <<: *trl
name: "cuda13-trl"
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:latest-cublas-cuda13-trl
- !!merge <<: *trl
name: "cuda13-trl-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
mirrors:
- localai/localai-backends:master-cublas-cuda13-trl


@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master
.PHONY: trl
trl:
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
.PHONY: run
run: trl
@echo "Running trl..."
bash run.sh
@echo "trl run."
.PHONY: test
test: trl
@echo "Testing trl..."
bash test.sh
@echo "trl tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__


@@ -0,0 +1,860 @@
#!/usr/bin/env python3
"""
TRL fine-tuning backend for LocalAI.
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
"""
import argparse
import json
import os
import queue
import signal
import sys
import threading
import time
import uuid
from concurrent import futures
import grpc
import backend_pb2
import backend_pb2_grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
class ProgressCallback:
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
def __init__(self, job_id, progress_queue, total_epochs):
self.job_id = job_id
self.progress_queue = progress_queue
self.total_epochs = total_epochs
def get_callback(self):
from transformers import TrainerCallback
parent = self
class _Callback(TrainerCallback):
def __init__(self):
self._train_start_time = None
def on_train_begin(self, args, state, control, **kwargs):
self._train_start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eta = 0.0
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
elapsed = time.time() - self._train_start_time
remaining_steps = total_steps - state.global_step
if state.global_step > 0:
eta = remaining_steps * (elapsed / state.global_step)
extra_metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(logs.get('epoch', 0)),
total_epochs=float(parent.total_epochs),
loss=float(logs.get('loss', 0)),
learning_rate=float(logs.get('learning_rate', 0)),
grad_norm=float(logs.get('grad_norm', 0)),
eval_loss=float(logs.get('eval_loss', 0)),
eta_seconds=float(eta),
progress_percent=float(progress),
status="training",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_prediction_step(self, args, state, control, **kwargs):
"""Send periodic updates during evaluation so the UI doesn't freeze."""
if not hasattr(self, '_eval_update_counter'):
self._eval_update_counter = 0
self._eval_update_counter += 1
# Throttle: send an update every 10 prediction steps
if self._eval_update_counter % 10 != 0:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
progress_percent=float(progress),
status="training",
message=f"Evaluating... (batch {self._eval_update_counter})",
)
parent.progress_queue.put(update)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
"""Report eval results once evaluation is done."""
# Reset prediction counter for next eval round
self._eval_update_counter = 0
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eval_loss = 0.0
extra_metrics = {}
if metrics:
eval_loss = float(metrics.get('eval_loss', 0))
for k, v in metrics.items():
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
eval_loss=eval_loss,
progress_percent=float(progress),
status="training",
message=f"Evaluation complete at step {state.global_step}",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
status="saving",
message=f"Checkpoint saved at step {state.global_step}",
checkpoint_path=checkpoint_path,
)
parent.progress_queue.put(update)
def on_train_end(self, args, state, control, **kwargs):
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=state.max_steps,
progress_percent=100.0,
status="completed",
message="Training completed",
)
parent.progress_queue.put(update)
return _Callback()
class ActiveJob:
"""Represents an active fine-tuning job."""
def __init__(self, job_id):
self.job_id = job_id
self.progress_queue = queue.Queue()
self.trainer = None
self.thread = None
self.model = None
self.tokenizer = None
self.error = None
self.completed = False
self.stopped = False
def _is_gated_repo_error(exc):
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
try:
from huggingface_hub.utils import GatedRepoError
if isinstance(exc, GatedRepoError):
return True
except ImportError:
pass
msg = str(exc).lower()
if "gated repo" in msg or "access to model" in msg:
return True
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
if exc.response.status_code in (401, 403):
return True
return False
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.active_job = None
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
"""Accept LoadModel — actual model loading happens in StartFineTune."""
return backend_pb2.Result(success=True, message="OK")
def StartFineTune(self, request, context):
if self.active_job is not None and not self.active_job.completed:
return backend_pb2.FineTuneJobResult(
job_id="",
success=False,
message="A fine-tuning job is already running",
)
job_id = request.job_id if request.job_id else str(uuid.uuid4())
job = ActiveJob(job_id)
self.active_job = job
# Start training in background thread
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
job.thread = thread
thread.start()
return backend_pb2.FineTuneJobResult(
job_id=job_id,
success=True,
message="Fine-tuning job started",
)
def _run_training(self, request, job):
try:
self._do_training(request, job)
except Exception as e:
if _is_gated_repo_error(e):
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
else:
msg = f"Training failed: {e}"
job.error = msg
job.completed = True
update = backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id,
status="failed",
message=msg,
)
job.progress_queue.put(update)
# Send sentinel
job.progress_queue.put(None)
def _do_training(self, request, job):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
extra = dict(request.extra_options)
training_method = request.training_method or "sft"
training_type = request.training_type or "lora"
# Send loading status
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
))
# Determine device and dtype
device_map = "auto" if torch.cuda.is_available() else "cpu"
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
# Load model
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
if hf_token:
model_kwargs["token"] = hf_token
if extra.get("trust_remote_code", "false").lower() == "true":
model_kwargs["trust_remote_code"] = True
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
from transformers import BitsAndBytesConfig
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
job.model = model
job.tokenizer = tokenizer
# Apply LoRA if requested
if training_type == "lora":
from peft import LoraConfig, get_peft_model
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
target_modules = list(request.target_modules) if request.target_modules else None
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules or "all-linear",
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
# Load dataset
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
))
dataset_split = request.dataset_split or "train"
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
elif request.dataset_source.endswith('.csv'):
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
# Eval dataset setup
eval_dataset = None
eval_strategy = extra.get("eval_strategy", "steps")
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
if eval_strategy != "no":
eval_split = extra.get("eval_split")
eval_dataset_source = extra.get("eval_dataset_source")
if eval_split:
# Load a specific split as eval dataset
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
elif request.dataset_source.endswith('.csv'):
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
elif eval_dataset_source:
# Load eval dataset from a separate source
eval_dataset = load_dataset(eval_dataset_source, split="train")
else:
# Auto-split from training set
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
split = dataset.train_test_split(test_size=eval_split_ratio)
dataset = split["train"]
eval_dataset = split["test"]
if eval_strategy == "no":
eval_dataset = None
# Training config
output_dir = request.output_dir or f"./output-{job.job_id}"
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
batch_size = request.batch_size if request.batch_size > 0 else 2
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
max_steps = request.max_steps if request.max_steps > 0 else -1
save_steps = request.save_steps if request.save_steps > 0 else 500
seed = request.seed if request.seed > 0 else 3407
optimizer = request.optimizer or "adamw_torch"
# Checkpoint save controls
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
# CPU vs GPU training args (can be overridden via extra_options)
use_cpu = not torch.cuda.is_available()
common_train_kwargs = {}
if use_cpu:
common_train_kwargs["use_cpu"] = True
common_train_kwargs["fp16"] = False
common_train_kwargs["bf16"] = False
common_train_kwargs["gradient_checkpointing"] = False
else:
common_train_kwargs["bf16"] = True
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
# Allow extra_options to override training kwargs
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
if flag in extra:
common_train_kwargs[flag] = extra[flag].lower() == "true"
# Create progress callback
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
# Build save kwargs (shared across all methods)
_save_kwargs = {}
if save_strategy == "steps" and save_steps > 0:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
elif save_strategy == "epoch":
_save_kwargs["save_strategy"] = "epoch"
elif save_strategy == "no":
_save_kwargs["save_strategy"] = "no"
else:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
if save_total_limit:
_save_kwargs["save_total_limit"] = save_total_limit
# Eval kwargs
_eval_kwargs = {}
if eval_dataset is not None:
_eval_kwargs["eval_strategy"] = eval_strategy
_eval_kwargs["eval_steps"] = eval_steps
# Common training arguments shared by all methods
_common_args = dict(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=lr,
gradient_accumulation_steps=grad_accum,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
max_steps=max_steps,
seed=seed,
optim=optimizer,
logging_steps=1,
report_to="none",
**_save_kwargs,
**common_train_kwargs,
**_eval_kwargs,
)
# Select trainer based on training method
if training_method == "sft":
from trl import SFTTrainer, SFTConfig
max_length = int(extra.get("max_seq_length", "512"))
packing = extra.get("packing", "false").lower() == "true"
training_args = SFTConfig(
max_length=max_length,
packing=packing,
**_common_args,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "dpo":
from trl import DPOTrainer, DPOConfig
beta = float(extra.get("beta", "0.1"))
loss_type = extra.get("loss_type", "sigmoid")
max_length = int(extra.get("max_length", "512"))
training_args = DPOConfig(
beta=beta,
loss_type=loss_type,
max_length=max_length,
**_common_args,
)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "grpo":
from trl import GRPOTrainer, GRPOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = GRPOConfig(
num_generations=num_generations,
max_completion_length=max_completion_length,
**_common_args,
)
# GRPO requires reward functions passed via extra_options as a JSON list
from reward_functions import build_reward_functions
reward_funcs = []
if extra.get("reward_funcs"):
reward_funcs = build_reward_functions(extra["reward_funcs"])
if not reward_funcs:
raise ValueError(
"GRPO requires at least one reward function. "
"Specify reward_functions in the request or "
"reward_funcs in extra_options."
)
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_funcs=reward_funcs,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "orpo":
from trl import ORPOTrainer, ORPOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = ORPOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = ORPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "kto":
from trl import KTOTrainer, KTOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = KTOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = KTOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "rloo":
from trl import RLOOTrainer, RLOOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = RLOOConfig(
num_generations=num_generations,
max_new_tokens=max_completion_length,
**_common_args,
)
trainer = RLOOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "reward":
from trl import RewardTrainer, RewardConfig
max_length = int(extra.get("max_length", "512"))
training_args = RewardConfig(
max_length=max_length,
**_common_args,
)
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
else:
raise ValueError(f"Unsupported training method: {training_method}. "
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
job.trainer = trainer
# Start training
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="training", message="Training started",
))
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
trainer.train(resume_from_checkpoint=resume_ckpt)
# Save final model
trainer.save_model(output_dir)
if tokenizer:
tokenizer.save_pretrained(output_dir)
job.completed = True
# Sentinel to signal stream end
job.progress_queue.put(None)
def FineTuneProgress(self, request, context):
if self.active_job is None or self.active_job.job_id != request.job_id:
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Job {request.job_id} not found")
return
job = self.active_job
while True:
try:
update = job.progress_queue.get(timeout=1.0)
if update is None:
break
yield update
if update.status in ("completed", "failed", "stopped"):
break
except queue.Empty:
if job.completed or job.stopped:
break
if not context.is_active():
break
continue
def StopFineTune(self, request, context):
# Stopping is handled by killing the process from Go via ShutdownModel.
return backend_pb2.Result(success=True, message="OK")
def ListCheckpoints(self, request, context):
output_dir = request.output_dir
if not os.path.isdir(output_dir):
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
checkpoints = []
for entry in sorted(os.listdir(output_dir)):
if entry.startswith("checkpoint-"):
ckpt_path = os.path.join(output_dir, entry)
if not os.path.isdir(ckpt_path):
continue
step = 0
try:
step = int(entry.split("-")[1])
except (IndexError, ValueError):
pass
# Try to read trainer_state.json for metadata
loss = 0.0
epoch = 0.0
state_file = os.path.join(ckpt_path, "trainer_state.json")
if os.path.exists(state_file):
try:
with open(state_file) as f:
state = json.load(f)
if state.get("log_history"):
last_log = state["log_history"][-1]
loss = last_log.get("loss", 0.0)
epoch = last_log.get("epoch", 0.0)
except Exception:
pass
created_at = time.strftime(
"%Y-%m-%dT%H:%M:%SZ",
time.gmtime(os.path.getmtime(ckpt_path)),
)
checkpoints.append(backend_pb2.CheckpointInfo(
path=ckpt_path,
step=step,
epoch=float(epoch),
loss=float(loss),
created_at=created_at,
))
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
def ExportModel(self, request, context):
export_format = request.export_format or "lora"
output_path = request.output_path
checkpoint_path = request.checkpoint_path
# Extract HF token for gated model access
extra = dict(request.extra_options) if request.extra_options else {}
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
if not checkpoint_path or not os.path.isdir(checkpoint_path):
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
os.makedirs(output_path, exist_ok=True)
try:
if export_format == "lora":
# Just copy the adapter files
import shutil
for f in os.listdir(checkpoint_path):
src = os.path.join(checkpoint_path, f)
dst = os.path.join(output_path, f)
if os.path.isfile(src):
shutil.copy2(src, dst)
elif export_format in ("merged_16bit", "merged_4bit"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for merge export")
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(output_path)
elif export_format == "gguf":
import torch
import subprocess
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
# Step 1: Merge LoRA into base model
merge_dir = os.path.join(output_path, "_hf_merged")
os.makedirs(merge_dir, exist_ok=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(merge_dir)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(merge_dir)
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
# Gemma models need this file for GGUF conversion to use the
# SentencePiece path; without it, the script falls back to BPE
# handling which fails on unrecognized pre-tokenizer hashes.
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
if not os.path.exists(sp_model_path):
sp_copied = False
# Method 1: Load the slow tokenizer which keeps the SP model file
try:
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
import shutil as _shutil
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from slow tokenizer cache")
except Exception as e:
print(f"Slow tokenizer method failed: {e}")
# Method 2: Download from HF hub
if not sp_copied:
try:
from huggingface_hub import hf_hub_download
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
import shutil as _shutil
_shutil.copy2(cached_sp, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from HF hub")
except Exception as e:
print(f"HF hub download method failed: {e}")
if not sp_copied:
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
"GGUF conversion may fail for SentencePiece models.")
# Free GPU memory before conversion
del merged, model, base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
quant = request.quantization_method or "auto"
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
outtype = outtype_map.get(quant, "f16")
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
gguf_path = os.path.join(output_path, gguf_filename)
script_dir = os.path.dirname(os.path.abspath(__file__))
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
if not os.path.exists(convert_script):
return backend_pb2.Result(success=False,
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
# Log merge_dir contents for debugging conversion issues
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
print(f"Merge dir contents: {merge_files}", flush=True)
env = os.environ.copy()
env["NO_LOCAL_GGUF"] = "1"
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
if conv_result.returncode != 0:
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
return backend_pb2.Result(success=False,
message=f"GGUF conversion failed: {diag}")
# Clean up intermediate merged model
shutil.rmtree(merge_dir, ignore_errors=True)
else:
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
except Exception as e:
if _is_gated_repo_error(e):
return backend_pb2.Result(success=False,
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
# Handle graceful shutdown
def stop(signum, frame):
server.stop(0)
sys.exit(0)
signal.signal(signal.SIGTERM, stop)
signal.signal(signal.SIGINT, stop)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
args = parser.parse_args()
serve(args.addr)
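
The `ListCheckpoints` handler above recovers step, loss, and epoch for each `checkpoint-<N>` directory from its `trainer_state.json`. A minimal, self-contained sketch of that parsing (the file layout is as written by the Hugging Face `Trainer`; the helper name is ours):

```python
import json

def checkpoint_metadata(state_json: str, dirname: str) -> dict:
    # Step comes from the directory name ("checkpoint-<step>");
    # loss/epoch come from the last log_history entry, as in ListCheckpoints.
    step = 0
    try:
        step = int(dirname.split("-")[1])
    except (IndexError, ValueError):
        pass
    loss, epoch = 0.0, 0.0
    state = json.loads(state_json)
    if state.get("log_history"):
        last = state["log_history"][-1]
        loss = last.get("loss", 0.0)
        epoch = last.get("epoch", 0.0)
    return {"step": step, "loss": loss, "epoch": epoch}

state = json.dumps({"log_history": [{"loss": 1.9, "epoch": 0.5},
                                    {"loss": 1.2, "epoch": 1.0}]})
print(checkpoint_metadata(state, "checkpoint-500"))
# → {'step': 500, 'loss': 1.2, 'epoch': 1.0}
```

Missing or malformed metadata degrades gracefully to zeros, matching the backend's behavior.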


@@ -0,0 +1,37 @@
#!/bin/bash
set -e
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
curl -L --fail --retry 3 \
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi
# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
pip install "gguf>=0.16.0"
}
else
uv pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
uv pip install "gguf>=0.16.0"
}
fi
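
The GGUF branch of `ExportModel` in `backend.py` above pairs this downloaded script with a quantization-method-to-`--outtype` mapping, falling back to `f16` for unknown methods. A sketch of how that conversion command is assembled (the paths here are placeholders):

```python
import sys

# Mapping used by ExportModel; unrecognized methods fall back to "f16".
OUTTYPE_MAP = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}

def build_convert_cmd(merge_dir: str, gguf_path: str, quant: str) -> list:
    # Assemble the convert_hf_to_gguf.py invocation run via subprocess.
    outtype = OUTTYPE_MAP.get(quant, "f16")
    return [sys.executable, "convert_hf_to_gguf.py", merge_dir,
            "--outtype", outtype, "--outfile", gguf_path]

cmd = build_convert_cmd("/tmp/_hf_merged", "/tmp/out-q8_0.gguf", "q8_0")
print(cmd[3:5])  # → ['--outtype', 'q8_0']
```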


@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece


@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes


@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes


@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi


@@ -0,0 +1,236 @@
"""
Built-in reward functions and inline function compiler for GRPO training.
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
"""
import json
import re
import math
import string
import functools
# ---------------------------------------------------------------------------
# Built-in reward functions
# ---------------------------------------------------------------------------
def format_reward(completions, **kwargs):
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
return [1.0 if pattern.search(c) else 0.0 for c in completions]
def reasoning_accuracy_reward(completions, **kwargs):
"""Extracts <answer>...</answer> content and compares to the expected answer."""
answers = kwargs.get("answer", [])
if not answers:
return [0.0] * len(completions)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
scores = []
for i, c in enumerate(completions):
expected = answers[i] if i < len(answers) else ""
match = pattern.search(c)
if match:
extracted = match.group(1).strip()
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
else:
scores.append(0.0)
return scores
def length_reward(completions, target_length=200, **kwargs):
"""Score based on proximity to target_length. Returns [0, 1]."""
scores = []
for c in completions:
length = len(c)
if target_length <= 0:
scores.append(0.0)
else:
diff = abs(length - target_length) / target_length
scores.append(max(0.0, 1.0 - diff))
return scores
def xml_tag_reward(completions, **kwargs):
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
tags = ["think", "answer"]
scores = []
for c in completions:
tag_score = 0.0
for tag in tags:
if f"<{tag}>" in c and f"</{tag}>" in c:
tag_score += 0.5
scores.append(min(tag_score, 1.0))
return scores
def no_repetition_reward(completions, n=4, **kwargs):
"""Penalizes n-gram repetition. Returns [0, 1]."""
scores = []
for c in completions:
words = c.split()
if len(words) < n:
scores.append(1.0)
continue
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
unique = len(set(ngrams))
total = len(ngrams)
scores.append(unique / total if total > 0 else 1.0)
return scores
def code_execution_reward(completions, **kwargs):
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
scores = []
for c in completions:
match = pattern.search(c)
if not match:
scores.append(0.0)
continue
code = match.group(1)
try:
compile(code, "<inline>", "exec")
scores.append(1.0)
except SyntaxError:
scores.append(0.0)
return scores
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
BUILTIN_REGISTRY = {
"format_reward": format_reward,
"reasoning_accuracy_reward": reasoning_accuracy_reward,
"length_reward": length_reward,
"xml_tag_reward": xml_tag_reward,
"no_repetition_reward": no_repetition_reward,
"code_execution_reward": code_execution_reward,
}
# ---------------------------------------------------------------------------
# Inline function compiler
# ---------------------------------------------------------------------------
_SAFE_BUILTINS = {
"len": len, "int": int, "float": float, "str": str, "bool": bool,
"list": list, "dict": dict, "tuple": tuple, "set": set,
"range": range, "enumerate": enumerate, "zip": zip,
"map": map, "filter": filter, "sorted": sorted,
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
"any": any, "all": all, "isinstance": isinstance, "type": type,
"print": print, "True": True, "False": False, "None": None,
"ValueError": ValueError, "TypeError": TypeError,
"KeyError": KeyError, "IndexError": IndexError,
}
def compile_inline_reward(name, code):
"""Compile user-provided code into a reward function.
The code should be the body of a function that receives
`completions` (list[str]) and `**kwargs`, and returns list[float].
Available modules: re, math, json, string.
"""
func_source = (
f"def _user_reward_{name}(completions, **kwargs):\n"
+ "\n".join(f" {line}" for line in code.splitlines())
)
restricted_globals = {
"__builtins__": _SAFE_BUILTINS,
"re": re,
"math": math,
"json": json,
"string": string,
}
try:
compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
except SyntaxError as e:
raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
exec(compiled, restricted_globals)
func = restricted_globals[f"_user_reward_{name}"]
# Validate with a quick smoke test
try:
result = func(["test"], answer=["test"])
if not isinstance(result, list):
raise ValueError(
f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
)
except Exception as e:
if "must return a list" in str(e):
raise
# Other errors during smoke test are acceptable (e.g. missing kwargs)
pass
return func
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
def build_reward_functions(specs_json):
"""Parse a JSON list of reward function specs and return a list of callables.
Each spec is a dict with:
- type: "builtin" or "inline"
- name: function name
- code: (inline only) Python function body
- params: (optional) dict of string params applied via functools.partial
"""
if isinstance(specs_json, str):
specs = json.loads(specs_json)
else:
specs = specs_json
if not isinstance(specs, list):
raise ValueError("reward_funcs must be a JSON array of reward function specs")
reward_funcs = []
for spec in specs:
spec_type = spec.get("type", "builtin")
name = spec.get("name", "")
params = spec.get("params", {})
if spec_type == "builtin":
if name not in BUILTIN_REGISTRY:
available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
raise ValueError(
f"Unknown builtin reward function '{name}'. Available: {available}"
)
func = BUILTIN_REGISTRY[name]
if params:
# Convert string params to appropriate types
typed_params = {}
for k, v in params.items():
try:
typed_params[k] = int(v)
except (ValueError, TypeError):
try:
typed_params[k] = float(v)
except (ValueError, TypeError):
typed_params[k] = v
func = functools.partial(func, **typed_params)
reward_funcs.append(func)
elif spec_type == "inline":
code = spec.get("code", "")
if not code.strip():
raise ValueError(f"Inline reward function '{name}' has no code")
func = compile_inline_reward(name, code)
reward_funcs.append(func)
else:
raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
return reward_funcs
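
`compile_inline_reward` above wraps the user-supplied body in a generated `def` and `exec`s it against a restricted globals dict. A stripped-down, self-contained sketch of the same pattern (the example reward is hypothetical):

```python
import re

# Deliberately tiny builtin surface, mirroring _SAFE_BUILTINS.
SAFE_BUILTINS = {"len": len, "max": max, "min": min, "abs": abs}

def compile_reward(name, body):
    # Indent the body under a generated function definition,
    # as compile_inline_reward does.
    src = (f"def _user_reward_{name}(completions, **kwargs):\n"
           + "\n".join(f"    {line}" for line in body.splitlines()))
    g = {"__builtins__": SAFE_BUILTINS, "re": re}
    exec(compile(src, f"<inline-{name}>", "exec"), g)
    return g[f"_user_reward_{name}"]

# Reward completions that mention the word "because".
fn = compile_reward("because", "return [1.0 if 'because' in c else 0.0 for c in completions]")
print(fn(["because reasons", "nope"]))  # → [1.0, 0.0]
```

Note that restricting `__builtins__` raises the bar for accidental misuse but is not a security sandbox; inline rewards should still only be accepted from trusted operators.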

backend/python/trl/run.sh

@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
startBackend "$@"


@@ -0,0 +1,58 @@
"""
Test script for the TRL fine-tuning gRPC backend.
"""
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
"""Tests for the TRL fine-tuning gRPC service."""
def setUp(self):
self.service = subprocess.Popen(
["python3", "backend.py", "--addr", "localhost:50051"]
)
time.sleep(10)
def tearDown(self):
self.service.kill()
self.service.wait()
def test_server_startup(self):
"""Test that the server starts and responds to health checks."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_list_checkpoints_empty(self):
"""Test listing checkpoints on a non-existent directory."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.ListCheckpoints(
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
)
self.assertEqual(len(response.checkpoints), 0)
except Exception as err:
print(err)
self.fail("ListCheckpoints service failed")
finally:
self.tearDown()
if __name__ == '__main__':
unittest.main()


@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
runUnittests


@@ -121,6 +121,9 @@ type RunCMD struct {
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
// Fine-tuning
EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`
// Authentication
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
@@ -326,6 +329,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
}
// Fine-tuning
if r.EnableFineTuning {
opts = append(opts, config.EnableFineTuning)
}
// Authentication
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
if authEnabled {


@@ -97,6 +97,9 @@ type ApplicationConfig struct {
// Agent Pool (LocalAGI integration)
AgentPool AgentPoolConfig
// Fine-tuning
FineTuning FineTuningConfig
// Authentication & Authorization
Auth AuthConfig
}
@@ -142,6 +145,11 @@ type AgentPoolConfig struct {
AgentHubURL string // default: "https://agenthub.localai.io"
}
// FineTuningConfig holds configuration for fine-tuning support.
type FineTuningConfig struct {
Enabled bool
}
type AppOption func(*ApplicationConfig)
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
@@ -733,6 +741,12 @@ func WithAgentHubURL(url string) AppOption {
}
}
// Fine-tuning options
var EnableFineTuning = func(o *ApplicationConfig) {
o.FineTuning.Enabled = true
}
// Auth options
func WithAuthEnabled(enabled bool) AppOption {


@@ -0,0 +1,205 @@
package importers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/xlog"
)
// ImportLocalPath scans a local directory for exported model files and produces
// a config.ModelConfig with the correct backend, model path, and options.
// Paths in the returned config are relative to modelsPath when possible so that
// the YAML config remains portable.
//
// Detection order:
// 1. GGUF files (*.gguf) — uses llama-cpp backend
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
// Make paths relative to the models directory (parent of dirPath)
// so config YAML stays portable.
modelsDir := filepath.Dir(dirPath)
relPath := func(absPath string) string {
if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
return rel
}
return absPath
}
// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
ggufFile := findGGUF(dirPath)
if ggufFile == "" {
ggufSubdir := dirPath + "_gguf"
ggufFile = findGGUF(ggufSubdir)
}
if ggufFile != "" {
xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
cfg := &config.ModelConfig{
Name: name,
Backend: "llama-cpp",
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
}
cfg.Model = relPath(ggufFile)
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.Description = buildDescription(dirPath, "GGUF")
return cfg, nil
}
// 2. LoRA adapter: look for adapter_config.json
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
if fileExists(adapterConfigPath) {
xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
baseModel := readBaseModel(dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = baseModel
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
cfg.Description = buildDescription(dirPath, "LoRA adapter")
return cfg, nil
}
// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
baseModel := readBaseModel(dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = baseModel
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
cfg.Description = buildDescription(dirPath, "LoRA adapter")
return cfg, nil
}
// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
cfg := &config.ModelConfig{
Name: name,
Backend: "transformers",
KnownUsecaseStrings: []string{"chat"},
}
cfg.Model = relPath(dirPath)
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.Description = buildDescription(dirPath, "merged model")
return cfg, nil
}
return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
}
// findGGUF returns the path to the first .gguf file found in dir, or "".
func findGGUF(dir string) string {
entries, err := os.ReadDir(dir)
if err != nil {
return ""
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") {
return filepath.Join(dir, e.Name())
}
}
return ""
}
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
func readBaseModel(dirPath string) string {
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
return bm
}
}
}
// Try export_metadata.json → base_model (Unsloth writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
return bm
}
}
}
return ""
}
// buildDescription creates a human-readable description using available metadata.
func buildDescription(dirPath, formatLabel string) string {
base := ""
// Try adapter_config.json
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
base = bm
}
}
}
// Try export_metadata.json
if base == "" {
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
base = bm
}
}
}
}
if base != "" {
return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
}
return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}
func fileExists(path string) bool {
info, err := os.Stat(path)
return err == nil && !info.IsDir()
}
func hasFileWithSuffix(dir, suffix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
return true
}
}
return false
}
func hasFileWithPrefix(dir, prefix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) {
return true
}
}
return false
}


@@ -0,0 +1,148 @@
package importers_test
import (
"encoding/json"
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/gallery/importers"
)
var _ = Describe("ImportLocalPath", func() {
var tmpDir string
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "importers-local-test")
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
os.RemoveAll(tmpDir)
})
Context("GGUF detection", func() {
It("detects a GGUF file in the directory", func() {
modelDir := filepath.Join(tmpDir, "my-model")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("llama-cpp"))
Expect(cfg.Model).To(ContainSubstring(".gguf"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
Expect(cfg.Options).To(ContainElement("use_jinja:true"))
})
It("detects GGUF in _gguf subdirectory", func() {
modelDir := filepath.Join(tmpDir, "my-model")
ggufDir := modelDir + "_gguf"
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("llama-cpp"))
})
})
Context("LoRA adapter detection", func() {
It("detects LoRA adapter via adapter_config.json", func() {
modelDir := filepath.Join(tmpDir, "lora-model")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
"peft_type": "LORA",
}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
})
It("reads base model from export_metadata.json as fallback", func() {
modelDir := filepath.Join(tmpDir, "lora-unsloth")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{"peft_type": "LORA"}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
data, _ = json.Marshal(metadata)
Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
})
})
Context("Merged model detection", func() {
It("detects merged model with safetensors + config.json", func() {
modelDir := filepath.Join(tmpDir, "merged")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "merged")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("merged"))
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
})
It("detects merged model with pytorch_model files", func() {
modelDir := filepath.Join(tmpDir, "merged-pt")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Backend).To(Equal("transformers"))
Expect(cfg.Model).To(Equal("merged-pt"))
})
})
Context("fallback", func() {
It("returns error for empty directory", func() {
modelDir := filepath.Join(tmpDir, "empty")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
_, err := importers.ImportLocalPath(modelDir, "empty")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("could not detect model format"))
})
})
Context("description", func() {
It("includes base model name in description", func() {
modelDir := filepath.Join(tmpDir, "desc-test")
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
adapterConfig := map[string]any{
"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
}
data, _ := json.Marshal(adapterConfig)
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
})
})
})


@@ -302,6 +302,17 @@ func API(application *application.Application) (*echo.Echo, error) {
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
// Fine-tuning routes
if application.ApplicationConfig().FineTuning.Enabled {
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
ftService := services.NewFineTuneService(
application.ApplicationConfig(),
application.ModelLoader(),
application.ModelConfigLoader(),
)
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
}
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)


@@ -85,6 +85,18 @@ var RouteFeatureRegistry = []RouteFeature{
{"POST", "/stores/delete", FeatureStores},
{"POST", "/stores/get", FeatureStores},
{"POST", "/stores/find", FeatureStores},
// Fine-tuning
{"POST", "/api/fine-tuning/jobs", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
{"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning},
{"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning},
{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
}
// FeatureMeta describes a feature for the admin API/UI.
@@ -104,6 +116,13 @@ func AgentFeatureMetas() []FeatureMeta {
}
}
// GeneralFeatureMetas returns metadata for general features.
func GeneralFeatureMetas() []FeatureMeta {
return []FeatureMeta{
{FeatureFineTuning, "Fine-Tuning", false},
}
}
// APIFeatureMetas returns metadata for API endpoint features.
func APIFeatureMetas() []FeatureMeta {
return []FeatureMeta{


@@ -32,6 +32,9 @@ const (
FeatureCollections = "collections"
FeatureMCPJobs = "mcp_jobs"
// General features (default OFF for new users)
FeatureFineTuning = "fine_tuning"
// API features (default ON for new users)
FeatureChat = "chat"
FeatureImages = "images"
@@ -52,6 +55,9 @@ const (
// AgentFeatures lists agent-related features (default OFF).
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}
// GeneralFeatures lists general features (default OFF).
var GeneralFeatures = []string{FeatureFineTuning}
// APIFeatures lists API endpoint features (default ON).
var APIFeatures = []string{
FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
@@ -60,7 +66,7 @@ var APIFeatures = []string{
}
// AllFeatures lists all known features (used by UI and validation).
var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...)
var AllFeatures = append(append(append([]string{}, AgentFeatures...), GeneralFeatures...), APIFeatures...)
// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map.
var defaultOnFeatures = func() map[string]bool {


@@ -80,13 +80,14 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
defer src.Close()
if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil {
key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key})
}
}


@@ -0,0 +1,362 @@
package localai
import (
"archive/tar"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// StartFineTuneJobEndpoint starts a new fine-tuning job.
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
var req schema.FineTuneJobRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Invalid request: " + err.Error(),
})
}
if req.Model == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "model is required",
})
}
if req.DatasetSource == "" {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "dataset_source is required",
})
}
resp, err := ftService.StartJob(c.Request().Context(), userID, req)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusCreated, resp)
}
}
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobs := ftService.ListJobs(userID)
if jobs == nil {
jobs = []*schema.FineTuneJob{}
}
return c.JSON(http.StatusOK, jobs)
}
}
// GetFineTuneJobEndpoint gets a specific fine-tuning job.
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
job, err := ftService.GetJob(userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, job)
}
}
// StopFineTuneJobEndpoint stops a running fine-tuning job.
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
// Check for save_checkpoint query param
saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"status": "stopped",
"message": "Fine-tuning job stopped",
})
}
}
// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data.
func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
err := ftService.DeleteJob(userID, jobID)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
} else if strings.Contains(err.Error(), "cannot delete") {
status = http.StatusConflict
}
return c.JSON(status, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"status": "deleted",
"message": "Fine-tuning job deleted",
})
}
}
// FineTuneProgressEndpoint streams progress updates via SSE.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
// Set SSE headers
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().WriteHeader(http.StatusOK)
err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
data, err := json.Marshal(event)
if err != nil {
return
}
fmt.Fprintf(c.Response(), "data: %s\n\n", data)
c.Response().Flush()
})
if err != nil {
// If headers already sent, we can't send a JSON error
fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
c.Response().Flush()
}
return nil
}
}
// ListCheckpointsEndpoint lists checkpoints for a job.
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]any{
"checkpoints": checkpoints,
})
}
}
// ExportModelEndpoint exports a model from a checkpoint.
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
var req schema.ExportRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "Invalid request: " + err.Error(),
})
}
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusAccepted, map[string]string{
"status": "exporting",
"message": "Export started for model '" + modelName + "'",
"model_name": modelName,
})
}
}
// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive.
func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
userID := getUserID(c)
jobID := c.Param("id")
modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{
"error": err.Error(),
})
}
c.Response().Header().Set("Content-Type", "application/gzip")
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName))
c.Response().WriteHeader(http.StatusOK)
gw := gzip.NewWriter(c.Response())
defer gw.Close()
tw := tar.NewWriter(gw)
defer tw.Close()
err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return walkErr
}
relPath, err := filepath.Rel(modelDir, path)
if err != nil {
return err
}
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
header.Name = filepath.Join(modelName, relPath)
if err := tw.WriteHeader(header); err != nil {
return err
}
if info.IsDir() {
return nil
}
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(tw, f)
return err
})
if err != nil {
// Headers already sent, can't return JSON error
return err
}
return nil
}
}
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to list backends: " + err.Error(),
})
}
type backendInfo struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Tags []string `json:"tags,omitempty"`
}
var result []backendInfo
for _, b := range backends {
if !b.Installed {
continue
}
hasTag := false
for _, t := range b.Tags {
if strings.EqualFold(t, "fine-tuning") {
hasTag = true
break
}
}
if !hasTag {
continue
}
name := b.Name
if b.Alias != "" {
name = b.Alias
}
result = append(result, backendInfo{
Name: name,
Description: b.Description,
Tags: b.Tags,
})
}
if result == nil {
result = []backendInfo{}
}
return c.JSON(http.StatusOK, result)
}
}
// UploadDatasetEndpoint handles dataset file upload.
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
return func(c echo.Context) error {
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{
"error": "file is required",
})
}
src, err := file.Open()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to open file",
})
}
defer src.Close()
data, err := io.ReadAll(src)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": "failed to read file",
})
}
path, err := ftService.UploadDataset(file.Filename, data)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}
return c.JSON(http.StatusOK, map[string]string{
"path": path,
})
}
}


@@ -208,6 +208,32 @@
overflow: hidden;
}
.sidebar-section-toggle {
display: flex;
align-items: center;
justify-content: space-between;
width: 100%;
background: none;
border: none;
cursor: pointer;
font-family: inherit;
transition: color var(--duration-fast);
}
.sidebar-section-toggle:hover {
color: var(--color-text-secondary);
}
.sidebar-section-chevron {
font-size: 0.5rem;
transition: transform var(--duration-fast);
flex-shrink: 0;
}
.sidebar-section-toggle.open .sidebar-section-chevron {
transform: rotate(90deg);
}
.nav-item {
display: flex;
align-items: center;
@@ -392,6 +418,10 @@
display: none;
}
.sidebar.collapsed .sidebar-section-chevron {
display: none;
}
.sidebar.collapsed .nav-item {
justify-content: center;
padding: 8px 0;
@@ -612,14 +642,6 @@
.spinner-md .spinner-ring { width: 24px; height: 24px; }
.spinner-lg .spinner-ring { width: 40px; height: 40px; }
.spinner-logo {
animation: pulse 1.2s ease-in-out infinite;
object-fit: contain;
}
.spinner-sm .spinner-logo { width: 16px; height: 16px; }
.spinner-md .spinner-logo { width: 24px; height: 24px; }
.spinner-lg .spinner-logo { width: 40px; height: 40px; }
/* Model selector */
.model-selector {
background: var(--color-bg-tertiary);
@@ -2646,6 +2668,43 @@
font-size: 0.625rem;
}
/* Studio tabs */
.studio-tabs {
display: flex;
gap: 0;
border-bottom: 1px solid var(--color-border-subtle);
padding: 0 var(--spacing-xl);
background: var(--color-bg-primary);
position: sticky;
top: 0;
z-index: 10;
}
.studio-tab {
display: flex;
align-items: center;
gap: 6px;
background: none;
border: none;
padding: var(--spacing-sm) var(--spacing-md);
font-size: 0.8125rem;
font-family: inherit;
color: var(--color-text-secondary);
cursor: pointer;
border-bottom: 2px solid transparent;
transition: color var(--duration-fast), border-color var(--duration-fast);
}
.studio-tab:hover {
color: var(--color-text-primary);
}
.studio-tab-active {
color: var(--color-primary);
border-bottom-color: var(--color-primary);
font-weight: 500;
}
/* Two-column layout for media generation pages */
.media-layout {
display: grid;


@@ -1,22 +1,8 @@
import { useState } from 'react'
import { apiUrl } from '../utils/basePath'
export default function LoadingSpinner({ size = 'md', className = '' }) {
const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md'
const [imgFailed, setImgFailed] = useState(false)
return (
<div className={`spinner ${sizeClass} ${className}`}>
{imgFailed ? (
<div className="spinner-ring" />
) : (
<img
src={apiUrl('/static/logo.png')}
alt=""
className="spinner-logo"
onError={() => setImgFailed(true)}
/>
)}
<div className="spinner-ring" />
</div>
)
}


@@ -1,37 +1,57 @@
import { useState, useEffect } from 'react'
import { NavLink, useNavigate } from 'react-router-dom'
import { NavLink, useNavigate, useLocation } from 'react-router-dom'
import ThemeToggle from './ThemeToggle'
import { useAuth } from '../context/AuthContext'
import { apiUrl } from '../utils/basePath'
const COLLAPSED_KEY = 'localai_sidebar_collapsed'
const SECTIONS_KEY = 'localai_sidebar_sections'
const mainItems = [
const topItems = [
{ path: '/app', icon: 'fas fa-home', label: 'Home' },
{ path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true },
{ path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' },
{ path: '/app/image', icon: 'fas fa-image', label: 'Images' },
{ path: '/app/video', icon: 'fas fa-video', label: 'Video' },
{ path: '/app/tts', icon: 'fas fa-music', label: 'TTS' },
{ path: '/app/sound', icon: 'fas fa-volume-high', label: 'Sound' },
{ path: '/app/studio', icon: 'fas fa-palette', label: 'Studio' },
{ path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' },
{ path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
]
const agentItems = [
{ path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
{ path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
{ path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
{ path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
]
const systemItems = [
{ path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
{ path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
{ path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
{ path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
{ path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
{ path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
const sections = [
{
id: 'tools',
title: 'Tools',
items: [
{ path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' },
],
},
{
id: 'agents',
title: 'Agents',
featureMap: {
'/app/agents': 'agents',
'/app/skills': 'skills',
'/app/collections': 'collections',
'/app/agent-jobs': 'mcp_jobs',
},
items: [
{ path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
{ path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
{ path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
{ path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
],
},
{
id: 'system',
title: 'System',
items: [
{ path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
{ path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
{ path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
{ path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
{ path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
{ path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
{ path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
],
},
]
function NavItem({ item, onClose, collapsed }) {
@@ -51,18 +71,47 @@ function NavItem({ item, onClose, collapsed }) {
)
}
function loadSectionState() {
try {
const stored = localStorage.getItem(SECTIONS_KEY)
return stored ? JSON.parse(stored) : {}
} catch (_) {
return {}
}
}
function saveSectionState(state) {
try { localStorage.setItem(SECTIONS_KEY, JSON.stringify(state)) } catch (_) { /* ignore */ }
}
export default function Sidebar({ isOpen, onClose }) {
const [features, setFeatures] = useState({})
const [collapsed, setCollapsed] = useState(() => {
try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false }
})
const [openSections, setOpenSections] = useState(loadSectionState)
const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth()
const navigate = useNavigate()
const location = useLocation()
useEffect(() => {
fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {})
}, [])
// Auto-expand section containing the active route
useEffect(() => {
for (const section of sections) {
const match = section.items.some(item => location.pathname.startsWith(item.path))
if (match && !openSections[section.id]) {
setOpenSections(prev => {
const next = { ...prev, [section.id]: true }
saveSectionState(next)
return next
})
}
}
}, [location.pathname])
const toggleCollapse = () => {
setCollapsed(prev => {
const next = !prev
@@ -72,17 +121,34 @@ export default function Sidebar({ isOpen, onClose }) {
})
}
const visibleMainItems = mainItems.filter(item => {
if (item.adminOnly && !isAdmin) return false
if (item.authOnly && !authEnabled) return false
return true
})
const toggleSection = (id) => {
setOpenSections(prev => {
const next = { ...prev, [id]: !prev[id] }
saveSectionState(next)
return next
})
}
const visibleSystemItems = systemItems.filter(item => {
const filterItem = (item) => {
if (item.adminOnly && !isAdmin) return false
if (item.authOnly && !authEnabled) return false
if (item.feature && features[item.feature] === false) return false
if (item.feature && !hasFeature(item.feature)) return false
return true
})
}
const visibleTopItems = topItems.filter(filterItem)
const getVisibleSectionItems = (section) => {
return section.items.filter(item => {
if (!filterItem(item)) return false
if (section.featureMap) {
const featureName = section.featureMap[item.path]
return featureName ? hasFeature(featureName) : isAdmin
}
return true
})
}
return (
<>
@@ -104,57 +170,57 @@ export default function Sidebar({ isOpen, onClose }) {
{/* Navigation */}
<nav className="sidebar-nav">
{/* Main section */}
{/* Top-level items */}
<div className="sidebar-section">
{visibleMainItems.map(item => (
{visibleTopItems.map(item => (
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
))}
</div>
{/* Agents section (per-feature permissions) */}
{features.agents !== false && (() => {
const featureMap = {
'/app/agents': 'agents',
'/app/skills': 'skills',
'/app/collections': 'collections',
'/app/agent-jobs': 'mcp_jobs',
}
const visibleAgentItems = agentItems.filter(item => {
if (item.feature && features[item.feature] === false) return false
const featureName = featureMap[item.path]
return featureName ? hasFeature(featureName) : isAdmin
})
if (visibleAgentItems.length === 0) return null
{/* Collapsible sections */}
{sections.map(section => {
// For agents section, check global feature flag
if (section.id === 'agents' && features.agents === false) return null
const visibleItems = getVisibleSectionItems(section)
if (visibleItems.length === 0) return null
const isSectionOpen = openSections[section.id]
const showItems = isSectionOpen || collapsed
return (
<div className="sidebar-section">
<div className="sidebar-section-title">Agents</div>
{visibleAgentItems.map(item => (
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
))}
<div key={section.id} className="sidebar-section">
<button
className={`sidebar-section-title sidebar-section-toggle ${isSectionOpen ? 'open' : ''}`}
onClick={() => toggleSection(section.id)}
title={collapsed ? section.title : undefined}
>
<span>{section.title}</span>
<i className="fas fa-chevron-right sidebar-section-chevron" />
</button>
{showItems && (
<div className="sidebar-section-items">
{section.id === 'system' && (
<a
href={apiUrl('/swagger/index.html')}
target="_blank"
rel="noopener noreferrer"
className="nav-item"
title={collapsed ? 'API' : undefined}
>
<i className="fas fa-code nav-icon" />
<span className="nav-label">API</span>
<i className="fas fa-external-link-alt nav-external" />
</a>
)}
{visibleItems.map(item => (
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
))}
</div>
)}
</div>
)
})()}
{/* System section */}
<div className="sidebar-section">
{visibleSystemItems.length > 0 && (
<div className="sidebar-section-title">System</div>
)}
<a
href={apiUrl('/swagger/index.html')}
target="_blank"
rel="noopener noreferrer"
className="nav-item"
title={collapsed ? 'API' : undefined}
>
<i className="fas fa-code nav-icon" />
<span className="nav-label">API</span>
<i className="fas fa-external-link-alt nav-external" />
</a>
{visibleSystemItems.map(item => (
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
))}
</div>
})}
</nav>
{/* Footer */}


File diff suppressed because it is too large.


@@ -0,0 +1,48 @@
import { useSearchParams } from 'react-router-dom'
import ImageGen from './ImageGen'
import VideoGen from './VideoGen'
import TTS from './TTS'
import Sound from './Sound'
const TABS = [
{ key: 'images', label: 'Images', icon: 'fas fa-image' },
{ key: 'video', label: 'Video', icon: 'fas fa-video' },
{ key: 'tts', label: 'TTS', icon: 'fas fa-headphones' },
{ key: 'sound', label: 'Sound', icon: 'fas fa-music' },
]
const TAB_COMPONENTS = {
images: ImageGen,
video: VideoGen,
tts: TTS,
sound: Sound,
}
export default function Studio() {
const [searchParams, setSearchParams] = useSearchParams()
const activeTab = searchParams.get('tab') || 'images'
const setTab = (key) => {
setSearchParams({ tab: key }, { replace: true })
}
const ActiveComponent = TAB_COMPONENTS[activeTab] || ImageGen
return (
<div>
<div className="studio-tabs">
{TABS.map(tab => (
<button
key={tab.key}
className={`studio-tab${activeTab === tab.key ? ' studio-tab-active' : ''}`}
onClick={() => setTab(tab.key)}
>
<i className={tab.icon} />
<span>{tab.label}</span>
</button>
))}
</div>
<ActiveComponent />
</div>
)
}


@@ -45,9 +45,11 @@ function PermissionSummary({ user, onClick }) {
const perms = user.permissions || {}
const apiFeatures = ['chat', 'images', 'audio_speech', 'audio_transcription', 'vad', 'detection', 'video', 'embeddings', 'sound']
const agentFeatures = ['agents', 'skills', 'collections', 'mcp_jobs']
const generalFeatures = ['fine_tuning']
const apiOn = apiFeatures.filter(f => perms[f] !== false && (perms[f] === true || perms[f] === undefined)).length
const agentOn = agentFeatures.filter(f => perms[f]).length
const generalOn = generalFeatures.filter(f => perms[f]).length
const modelRestricted = user.allowed_models?.enabled
@@ -58,7 +60,7 @@ function PermissionSummary({ user, onClick }) {
title="Edit permissions"
>
<i className="fas fa-shield-halved" />
{apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent
{apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent, {generalOn}/{generalFeatures.length} Features
{modelRestricted && ' | Models restricted'}
</button>
)
@@ -71,6 +73,7 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,
const apiFeatures = featureMeta?.api_features || []
const agentFeatures = featureMeta?.agent_features || []
const generalFeatures = featureMeta?.general_features || []
useEffect(() => {
const handleKeyDown = (e) => {
@@ -189,6 +192,33 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,
</div>
</div>
{/* General Features */}
{generalFeatures.length > 0 && (
<div className="perm-section">
<div className="perm-section-header">
<strong className="perm-section-title">
<i className="fas fa-sliders" />
Features
</strong>
<div className="action-group">
<button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, true)}>All</button>
<button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, false)}>None</button>
</div>
</div>
<div className="perm-grid">
{generalFeatures.map(f => (
<button
key={f.key}
className={`btn btn-sm ${permissions[f.key] ? 'btn-primary' : 'btn-secondary'} perm-btn-feature`}
onClick={() => toggleFeature(f.key)}
>
{f.label}
</button>
))}
</div>
</div>
)}
{/* Model Access */}
<div className="perm-section">
<div className="perm-section-header">
@@ -510,6 +540,9 @@ export default function Users() {
{ key: 'collections', label: 'Collections', default: false },
{ key: 'mcp_jobs', label: 'MCP CI Jobs', default: false },
],
general_features: [
{ key: 'fine_tuning', label: 'Fine-Tuning', default: false },
],
})
}
}, [])


@@ -31,6 +31,8 @@ import ImportModel from './pages/ImportModel'
import BackendLogs from './pages/BackendLogs'
import Explorer from './pages/Explorer'
import Login from './pages/Login'
import FineTune from './pages/FineTune'
import Studio from './pages/Studio'
import NotFound from './pages/NotFound'
import Usage from './pages/Usage'
import Users from './pages/Users'
@@ -44,6 +46,7 @@ function BrowseRedirect() {
return <Navigate to={`/app/${splat || ''}`} replace />
}
function Admin({ children }) {
return <RequireAdmin>{children}</RequireAdmin>
}
@@ -65,6 +68,7 @@ const appChildren = [
{ path: 'tts/:model', element: <TTS /> },
{ path: 'sound', element: <Sound /> },
{ path: 'sound/:model', element: <Sound /> },
{ path: 'studio', element: <Studio /> },
{ path: 'talk', element: <Talk /> },
{ path: 'usage', element: <Usage /> },
{ path: 'account', element: <Account /> },
@@ -90,6 +94,7 @@ const appChildren = [
{ path: 'agent-jobs/tasks/:id', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
{ path: 'agent-jobs/tasks/:id/edit', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
{ path: 'agent-jobs/jobs/:id', element: <Feature feature="mcp_jobs"><AgentJobDetails /></Feature> },
{ path: 'fine-tune', element: <Feature feature="fine_tuning"><FineTune /></Feature> },
{ path: 'model-editor/:name', element: <Admin><ModelEditor /></Admin> },
{ path: 'pipeline-editor', element: <Admin><PipelineEditor /></Admin> },
{ path: 'pipeline-editor/:name', element: <Admin><PipelineEditor /></Admin> },


@@ -380,6 +380,25 @@ export const apiKeysApi = {
revoke: (id) => fetchJSON(`/api/auth/api-keys/${encodeURIComponent(id)}`, { method: 'DELETE' }),
}
// Fine-tuning API
export const fineTuneApi = {
listBackends: () => fetchJSON('/api/fine-tuning/backends'),
startJob: (data) => postJSON('/api/fine-tuning/jobs', data),
listJobs: () => fetchJSON('/api/fine-tuning/jobs'),
getJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`),
stopJob: (id, saveCheckpoint) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/stop?save_checkpoint=${saveCheckpoint ? 'true' : 'false'}`, { method: 'POST' }),
deleteJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`, { method: 'DELETE' }),
listCheckpoints: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/checkpoints`),
exportModel: (id, data) => postJSON(`/api/fine-tuning/jobs/${enc(id)}/export`, data),
uploadDataset: (file) => {
const formData = new FormData()
formData.append('file', file)
return fetch(apiUrl('/api/fine-tuning/datasets'), { method: 'POST', body: formData }).then(handleResponse)
},
progressUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/progress`),
downloadUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/download`),
}
// File to base64 helper
export function fileToBase64(file) {
return new Promise((resolve, reject) => {


@@ -777,9 +777,10 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
}
return c.JSON(http.StatusOK, map[string]interface{}{
"agent_features": auth.AgentFeatureMetas(),
"general_features": auth.GeneralFeatureMetas(),
"api_features": auth.APIFeatureMetas(),
"models": modelNames,
})
}, adminMw)


@@ -0,0 +1,42 @@
package routes
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/services"
)
// RegisterFineTuningRoutes registers fine-tuning API routes.
func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) {
if ftService == nil {
return
}
// Service readiness middleware
readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if ftService == nil {
return c.JSON(http.StatusServiceUnavailable, map[string]string{
"error": "fine-tuning service is not available",
})
}
return next(c)
}
}
ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw)
ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig))
ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService))
ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService))
ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService))
ft.POST("/jobs/:id/stop", localai.StopFineTuneJobEndpoint(ftService))
ft.DELETE("/jobs/:id", localai.DeleteFineTuneJobEndpoint(ftService))
ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService))
ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService))
ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService))
ft.GET("/jobs/:id/download", localai.DownloadExportedModelEndpoint(ftService))
ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService))
}


@@ -134,8 +134,9 @@ func RegisterLocalAIRoutes(router *echo.Echo,
router.GET("/api/features", func(c echo.Context) error {
return c.JSON(200, map[string]bool{
"agents": appConfig.AgentPool.Enabled,
"mcp": !appConfig.DisableMCP,
"fine_tuning": appConfig.FineTuning.Enabled,
})
})

core/schema/finetune.go Normal file

@@ -0,0 +1,111 @@
package schema
// RewardFunctionSpec defines a reward function for GRPO training.
type RewardFunctionSpec struct {
Type string `json:"type"` // "builtin" or "inline"
Name string `json:"name"`
Code string `json:"code,omitempty"` // inline only
Params map[string]string `json:"params,omitempty"`
}
// FineTuneJobRequest is the REST API request to start a fine-tuning job.
type FineTuneJobRequest struct {
Model string `json:"model"`
Backend string `json:"backend"` // "trl"
TrainingType string `json:"training_type,omitempty"` // lora, loha, lokr, full
TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo
// Adapter config
AdapterRank int32 `json:"adapter_rank,omitempty"`
AdapterAlpha int32 `json:"adapter_alpha,omitempty"`
AdapterDropout float32 `json:"adapter_dropout,omitempty"`
TargetModules []string `json:"target_modules,omitempty"`
// Training hyperparameters
LearningRate float32 `json:"learning_rate,omitempty"`
NumEpochs int32 `json:"num_epochs,omitempty"`
BatchSize int32 `json:"batch_size,omitempty"`
GradientAccumulationSteps int32 `json:"gradient_accumulation_steps,omitempty"`
WarmupSteps int32 `json:"warmup_steps,omitempty"`
MaxSteps int32 `json:"max_steps,omitempty"`
SaveSteps int32 `json:"save_steps,omitempty"`
WeightDecay float32 `json:"weight_decay,omitempty"`
GradientCheckpointing bool `json:"gradient_checkpointing,omitempty"`
Optimizer string `json:"optimizer,omitempty"`
Seed int32 `json:"seed,omitempty"`
MixedPrecision string `json:"mixed_precision,omitempty"`
// Dataset
DatasetSource string `json:"dataset_source"`
DatasetSplit string `json:"dataset_split,omitempty"`
// Resume from a checkpoint
ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"`
// GRPO reward functions
RewardFunctions []RewardFunctionSpec `json:"reward_functions,omitempty"`
// Backend-specific and method-specific options
ExtraOptions map[string]string `json:"extra_options,omitempty"`
}
// FineTuneJob represents a fine-tuning job with its current state.
type FineTuneJob struct {
ID string `json:"id"`
UserID string `json:"user_id,omitempty"`
Model string `json:"model"`
Backend string `json:"backend"`
ModelID string `json:"model_id,omitempty"` // backend model loader ID
TrainingType string `json:"training_type"`
TrainingMethod string `json:"training_method"`
Status string `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped
Message string `json:"message,omitempty"`
OutputDir string `json:"output_dir"`
ExtraOptions map[string]string `json:"extra_options,omitempty"`
CreatedAt string `json:"created_at"`
// Export state (tracked separately from training status)
ExportStatus string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed"
ExportMessage string `json:"export_message,omitempty"`
ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export
// Full config for resume/reuse
Config *FineTuneJobRequest `json:"config,omitempty"`
}
// FineTuneJobResponse is the REST API response when creating a job.
type FineTuneJobResponse struct {
ID string `json:"id"`
Status string `json:"status"`
Message string `json:"message"`
}
// FineTuneProgressEvent is an SSE event for training progress.
type FineTuneProgressEvent struct {
JobID string `json:"job_id"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
CurrentEpoch float32 `json:"current_epoch"`
TotalEpochs float32 `json:"total_epochs"`
Loss float32 `json:"loss"`
LearningRate float32 `json:"learning_rate"`
GradNorm float32 `json:"grad_norm"`
EvalLoss float32 `json:"eval_loss"`
EtaSeconds float32 `json:"eta_seconds"`
ProgressPercent float32 `json:"progress_percent"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
CheckpointPath string `json:"checkpoint_path,omitempty"`
SamplePath string `json:"sample_path,omitempty"`
ExtraMetrics map[string]float32 `json:"extra_metrics,omitempty"`
}
// ExportRequest is the REST API request to export a model.
type ExportRequest struct {
Name string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
CheckpointPath string `json:"checkpoint_path"`
ExportFormat string `json:"export_format"` // lora, merged_16bit, merged_4bit, gguf
QuantizationMethod string `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16
Model string `json:"model,omitempty"` // base model name for merge
ExtraOptions map[string]string `json:"extra_options,omitempty"`
}


@@ -1042,7 +1042,7 @@ func (s *AgentPoolService) CreateCollection(name string) error {
return s.collectionsBackend.CreateCollection(name)
}
func (s *AgentPoolService) UploadToCollection(collection, filename string, fileBody io.Reader) (string, error) {
return s.collectionsBackend.Upload(collection, filename, fileBody)
}
@@ -1554,10 +1554,10 @@ func (s *AgentPoolService) CreateCollectionForUser(userID, name string) error {
}
// UploadToCollectionForUser uploads to a collection for a specific user.
func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) (string, error) {
backend, err := s.CollectionsBackendForUser(userID)
if err != nil {
return "", err
}
return backend.Upload(collection, filename, fileBody)
}

core/services/finetune.go Normal file

@@ -0,0 +1,700 @@
package services
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery/importers"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
"gopkg.in/yaml.v3"
)
// FineTuneService manages fine-tuning jobs and their lifecycle.
type FineTuneService struct {
appConfig *config.ApplicationConfig
modelLoader *model.ModelLoader
configLoader *config.ModelConfigLoader
mu sync.Mutex
jobs map[string]*schema.FineTuneJob
}
// NewFineTuneService creates a new FineTuneService.
func NewFineTuneService(
appConfig *config.ApplicationConfig,
modelLoader *model.ModelLoader,
configLoader *config.ModelConfigLoader,
) *FineTuneService {
s := &FineTuneService{
appConfig: appConfig,
modelLoader: modelLoader,
configLoader: configLoader,
jobs: make(map[string]*schema.FineTuneJob),
}
s.loadAllJobs()
return s
}
// fineTuneBaseDir returns the base directory for fine-tune job data.
func (s *FineTuneService) fineTuneBaseDir() string {
return filepath.Join(s.appConfig.DataPath, "fine-tune")
}
// jobDir returns the directory for a specific job.
func (s *FineTuneService) jobDir(jobID string) string {
return filepath.Join(s.fineTuneBaseDir(), jobID)
}
// saveJobState persists a job's state to disk as state.json.
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
dir := s.jobDir(job.ID)
if err := os.MkdirAll(dir, 0750); err != nil {
xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
return
}
data, err := json.MarshalIndent(job, "", " ")
if err != nil {
xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
return
}
statePath := filepath.Join(dir, "state.json")
if err := os.WriteFile(statePath, data, 0640); err != nil {
xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
}
}
// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
func (s *FineTuneService) loadAllJobs() {
baseDir := s.fineTuneBaseDir()
entries, err := os.ReadDir(baseDir)
if err != nil {
// Directory doesn't exist yet — that's fine
return
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
statePath := filepath.Join(baseDir, entry.Name(), "state.json")
data, err := os.ReadFile(statePath)
if err != nil {
continue
}
var job schema.FineTuneJob
if err := json.Unmarshal(data, &job); err != nil {
xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
continue
}
// Jobs that were running when we shut down are now stale
if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" {
job.Status = "stopped"
job.Message = "Server restarted while job was running"
}
// Exports that were in progress are now stale
if job.ExportStatus == "exporting" {
job.ExportStatus = "failed"
job.ExportMessage = "Server restarted while export was running"
}
s.jobs[job.ID] = &job
}
if len(s.jobs) > 0 {
xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
}
}
// StartJob starts a new fine-tuning job.
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
jobID := uuid.New().String()
backendName := req.Backend
if backendName == "" {
backendName = "trl"
}
// Always use DataPath for output — not user-configurable
outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)
// Build gRPC request
grpcReq := &pb.FineTuneRequest{
Model: req.Model,
TrainingType: req.TrainingType,
TrainingMethod: req.TrainingMethod,
AdapterRank: req.AdapterRank,
AdapterAlpha: req.AdapterAlpha,
AdapterDropout: req.AdapterDropout,
TargetModules: req.TargetModules,
LearningRate: req.LearningRate,
NumEpochs: req.NumEpochs,
BatchSize: req.BatchSize,
GradientAccumulationSteps: req.GradientAccumulationSteps,
WarmupSteps: req.WarmupSteps,
MaxSteps: req.MaxSteps,
SaveSteps: req.SaveSteps,
WeightDecay: req.WeightDecay,
GradientCheckpointing: req.GradientCheckpointing,
Optimizer: req.Optimizer,
Seed: req.Seed,
MixedPrecision: req.MixedPrecision,
DatasetSource: req.DatasetSource,
DatasetSplit: req.DatasetSplit,
OutputDir: outputDir,
JobId: jobID,
ResumeFromCheckpoint: req.ResumeFromCheckpoint,
ExtraOptions: req.ExtraOptions,
}
// Serialize reward functions into extra_options for the backend
if len(req.RewardFunctions) > 0 {
rfJSON, err := json.Marshal(req.RewardFunctions)
if err != nil {
return nil, fmt.Errorf("failed to serialize reward functions: %w", err)
}
if grpcReq.ExtraOptions == nil {
grpcReq.ExtraOptions = make(map[string]string)
}
grpcReq.ExtraOptions["reward_funcs"] = string(rfJSON)
}
// Load the fine-tuning backend (per-job model ID so multiple jobs can run concurrently)
modelID := backendName + "-finetune-" + jobID
backendModel, err := s.modelLoader.Load(
model.WithBackendString(backendName),
model.WithModel(backendName),
model.WithModelID(modelID),
)
if err != nil {
return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
}
// Start fine-tuning via gRPC
result, err := backendModel.StartFineTune(ctx, grpcReq)
if err != nil {
return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
}
if !result.Success {
return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
}
// Track the job
job := &schema.FineTuneJob{
ID: jobID,
UserID: userID,
Model: req.Model,
Backend: backendName,
ModelID: modelID,
TrainingType: req.TrainingType,
TrainingMethod: req.TrainingMethod,
Status: "queued",
OutputDir: outputDir,
ExtraOptions: req.ExtraOptions,
CreatedAt: time.Now().UTC().Format(time.RFC3339),
Config: &req,
}
s.jobs[jobID] = job
s.saveJobState(job)
return &schema.FineTuneJobResponse{
ID: jobID,
Status: "queued",
Message: result.Message,
}, nil
}
// GetJob returns a fine-tuning job by ID.
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
s.mu.Lock()
defer s.mu.Unlock()
job, ok := s.jobs[jobID]
if !ok {
return nil, fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
return nil, fmt.Errorf("job not found: %s", jobID)
}
return job, nil
}
// ListJobs returns all jobs for a user, sorted by creation time (newest first).
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
s.mu.Lock()
defer s.mu.Unlock()
var result []*schema.FineTuneJob
for _, job := range s.jobs {
if userID == "" || job.UserID == userID {
result = append(result, job)
}
}
sort.Slice(result, func(i, j int) bool {
return result[i].CreatedAt > result[j].CreatedAt
})
return result
}
// StopJob stops a running fine-tuning job.
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
s.mu.Unlock()
// Kill the backend process directly
stopModelID := job.ModelID
if stopModelID == "" {
stopModelID = job.Backend + "-finetune"
}
s.modelLoader.ShutdownModel(stopModelID)
s.mu.Lock()
job.Status = "stopped"
job.Message = "Training stopped by user"
s.saveJobState(job)
s.mu.Unlock()
return nil
}
// DeleteJob removes a fine-tuning job and its associated data from disk.
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
// Reject deletion of actively running jobs
activeStatuses := map[string]bool{
"queued": true, "loading_model": true, "loading_dataset": true,
"training": true, "saving": true,
}
if activeStatuses[job.Status] {
s.mu.Unlock()
return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status)
}
if job.ExportStatus == "exporting" {
s.mu.Unlock()
return fmt.Errorf("cannot delete job %s: export in progress", jobID)
}
exportModelName := job.ExportModelName
delete(s.jobs, jobID)
s.mu.Unlock()
// Remove job directory (state.json, checkpoints, output)
jobDir := s.jobDir(jobID)
if err := os.RemoveAll(jobDir); err != nil {
xlog.Warn("Failed to remove job directory", "job_id", jobID, "path", jobDir, "error", err)
}
// If an exported model exists, clean it up too
if exportModelName != "" {
modelsPath := s.appConfig.SystemState.Model.ModelsPath
modelDir := filepath.Join(modelsPath, exportModelName)
configPath := filepath.Join(modelsPath, exportModelName+".yaml")
if err := os.RemoveAll(modelDir); err != nil {
xlog.Warn("Failed to remove exported model directory", "path", modelDir, "error", err)
}
if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
xlog.Warn("Failed to remove exported model config", "path", configPath, "error", err)
}
// Reload model configs
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
xlog.Warn("Failed to reload configs after delete", "error", err)
}
}
xlog.Info("Deleted fine-tune job", "job_id", jobID)
return nil
}
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return fmt.Errorf("job not found: %s", jobID)
}
s.mu.Unlock()
streamModelID := job.ModelID
if streamModelID == "" {
streamModelID = job.Backend + "-finetune"
}
backendModel, err := s.modelLoader.Load(
model.WithBackendString(job.Backend),
model.WithModel(job.Backend),
model.WithModelID(streamModelID),
)
if err != nil {
return fmt.Errorf("failed to load backend: %w", err)
}
return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{
JobId: jobID,
}, func(update *pb.FineTuneProgressUpdate) {
// Update job status and persist
s.mu.Lock()
if j, ok := s.jobs[jobID]; ok {
// Don't let progress updates overwrite terminal states
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
if !isTerminal {
j.Status = update.Status
}
if update.Message != "" {
j.Message = update.Message
}
s.saveJobState(j)
}
s.mu.Unlock()
// Convert extra metrics
extraMetrics := make(map[string]float32)
for k, v := range update.ExtraMetrics {
extraMetrics[k] = v
}
event := &schema.FineTuneProgressEvent{
JobID: update.JobId,
CurrentStep: update.CurrentStep,
TotalSteps: update.TotalSteps,
CurrentEpoch: update.CurrentEpoch,
TotalEpochs: update.TotalEpochs,
Loss: update.Loss,
LearningRate: update.LearningRate,
GradNorm: update.GradNorm,
EvalLoss: update.EvalLoss,
EtaSeconds: update.EtaSeconds,
ProgressPercent: update.ProgressPercent,
Status: update.Status,
Message: update.Message,
CheckpointPath: update.CheckpointPath,
SamplePath: update.SamplePath,
ExtraMetrics: extraMetrics,
}
callback(event)
})
}
// ListCheckpoints lists checkpoints for a job.
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return nil, fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return nil, fmt.Errorf("job not found: %s", jobID)
}
s.mu.Unlock()
ckptModelID := job.ModelID
if ckptModelID == "" {
ckptModelID = job.Backend + "-finetune"
}
backendModel, err := s.modelLoader.Load(
model.WithBackendString(job.Backend),
model.WithModel(job.Backend),
model.WithModelID(ckptModelID),
)
if err != nil {
return nil, fmt.Errorf("failed to load backend: %w", err)
}
resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
OutputDir: job.OutputDir,
})
if err != nil {
return nil, fmt.Errorf("failed to list checkpoints: %w", err)
}
return resp.Checkpoints, nil
}
// sanitizeModelName replaces non-alphanumeric characters with hyphens and lowercases.
func sanitizeModelName(s string) string {
re := regexp.MustCompile(`[^a-zA-Z0-9\-]`)
s = re.ReplaceAllString(s, "-")
s = regexp.MustCompile(`-+`).ReplaceAllString(s, "-")
s = strings.Trim(s, "-")
return strings.ToLower(s)
}
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return "", fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return "", fmt.Errorf("job not found: %s", jobID)
}
if job.ExportStatus == "exporting" {
s.mu.Unlock()
return "", fmt.Errorf("export already in progress for job %s", jobID)
}
s.mu.Unlock()
// Compute model name
modelName := req.Name
if modelName == "" {
base := sanitizeModelName(job.Model)
if base == "" {
base = "model"
}
shortID := jobID
if len(shortID) > 8 {
shortID = shortID[:8]
}
modelName = base + "-ft-" + shortID
}
// Compute output path in models directory
modelsPath := s.appConfig.SystemState.Model.ModelsPath
outputPath := filepath.Join(modelsPath, modelName)
// Check for name collision (synchronous — fast validation)
configPath := filepath.Join(modelsPath, modelName+".yaml")
if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
return "", fmt.Errorf("invalid model name: %w", err)
}
if _, err := os.Stat(configPath); err == nil {
return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
}
// Create output directory
if err := os.MkdirAll(outputPath, 0750); err != nil {
return "", fmt.Errorf("failed to create output directory: %w", err)
}
// Set export status to "exporting" and persist
s.mu.Lock()
job.ExportStatus = "exporting"
job.ExportMessage = ""
job.ExportModelName = ""
s.saveJobState(job)
s.mu.Unlock()
// Launch the export in a background goroutine
go func() {
s.setExportMessage(job, "Loading export backend...")
exportModelID := job.ModelID
if exportModelID == "" {
exportModelID = job.Backend + "-finetune"
}
backendModel, err := s.modelLoader.Load(
model.WithBackendString(job.Backend),
model.WithModel(job.Backend),
model.WithModelID(exportModelID),
)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
return
}
// Merge job's extra_options (contains hf_token from training) with request's
mergedOpts := make(map[string]string)
for k, v := range job.ExtraOptions {
mergedOpts[k] = v
}
for k, v := range req.ExtraOptions {
mergedOpts[k] = v // request overrides job
}
grpcReq := &pb.ExportModelRequest{
CheckpointPath: req.CheckpointPath,
OutputPath: outputPath,
ExportFormat: req.ExportFormat,
QuantizationMethod: req.QuantizationMethod,
Model: req.Model,
ExtraOptions: mergedOpts,
}
s.setExportMessage(job, "Running model export (merging and converting — this may take a while)...")
result, err := backendModel.ExportModel(context.Background(), grpcReq)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
return
}
if !result.Success {
s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
return
}
s.setExportMessage(job, "Export complete, generating model configuration...")
// Auto-import: detect format and generate config
cfg, err := importers.ImportLocalPath(outputPath, modelName)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
return
}
cfg.Name = modelName
// If base model not detected from files, use the job's model field
if cfg.Model == "" && job.Model != "" {
cfg.Model = job.Model
}
// Write YAML config
yamlData, err := yaml.Marshal(cfg)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
return
}
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
return
}
s.setExportMessage(job, "Registering model with LocalAI...")
// Reload configs so the model is immediately available
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
xlog.Warn("Failed to reload configs after export", "error", err)
}
if err := s.configLoader.Preload(modelsPath); err != nil {
xlog.Warn("Failed to preload after export", "error", err)
}
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
s.mu.Lock()
job.ExportStatus = "completed"
job.ExportModelName = modelName
job.ExportMessage = ""
s.saveJobState(job)
s.mu.Unlock()
}()
return modelName, nil
}
// setExportMessage updates the export message and persists the job state.
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
s.mu.Lock()
job.ExportMessage = msg
s.saveJobState(job)
s.mu.Unlock()
}
// GetExportedModelPath returns the path to the exported model directory and its name.
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return "", "", fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return "", "", fmt.Errorf("job not found: %s", jobID)
}
if job.ExportStatus != "completed" {
s.mu.Unlock()
return "", "", fmt.Errorf("export not completed for job %s (status: %s)", jobID, job.ExportStatus)
}
exportModelName := job.ExportModelName
s.mu.Unlock()
if exportModelName == "" {
return "", "", fmt.Errorf("no exported model name for job %s", jobID)
}
modelsPath := s.appConfig.SystemState.Model.ModelsPath
modelDir := filepath.Join(modelsPath, exportModelName)
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
return "", "", fmt.Errorf("exported model directory not found: %s", modelDir)
}
return modelDir, exportModelName, nil
}
// setExportFailed sets the export status to failed with a message.
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
xlog.Error("Export failed", "job_id", job.ID, "error", message)
s.mu.Lock()
job.ExportStatus = "failed"
job.ExportMessage = message
s.saveJobState(job)
s.mu.Unlock()
}
// UploadDataset handles dataset file upload and returns the local path.
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
if err := os.MkdirAll(uploadDir, 0750); err != nil {
return "", fmt.Errorf("failed to create dataset directory: %w", err)
}
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename)
if err := os.WriteFile(filePath, data, 0640); err != nil {
return "", fmt.Errorf("failed to write dataset: %w", err)
}
return filePath, nil
}
// MarshalProgressEvent converts a progress event to JSON for SSE.
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
data, err := json.Marshal(event)
if err != nil {
return "", err
}
return string(data), nil
}
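For reference, `saveJobState` above persists each job as `<DataPath>/fine-tune/<job_id>/state.json`, and `loadAllJobs` marks jobs that were mid-flight at shutdown as stale. A small Python sketch of that restart behavior — the file layout and status strings come from this file, the concrete values are illustrative:

```python
import json
import os
import tempfile

# Recreate the on-disk layout the service uses:
#   <DataPath>/fine-tune/<job_id>/state.json
data_path = tempfile.mkdtemp()
job_id = "abc123"
job_dir = os.path.join(data_path, "fine-tune", job_id)
os.makedirs(job_dir)

state = {"id": job_id, "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         "backend": "trl", "status": "training"}
with open(os.path.join(job_dir, "state.json"), "w") as f:
    json.dump(state, f)

# On restart, loadAllJobs treats any in-flight status as stale:
ACTIVE = {"queued", "loading_model", "loading_dataset", "training", "saving"}
with open(os.path.join(job_dir, "state.json")) as f:
    loaded = json.load(f)
if loaded["status"] in ACTIVE:
    loaded["status"] = "stopped"
    loaded["message"] = "Server restarted while job was running"
assert loaded["status"] == "stopped"
```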


@@ -0,0 +1,226 @@
+++
disableToc = false
title = "Fine-Tuning"
weight = 18
url = '/features/fine-tuning/'
+++
LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types.
## Supported Backends
| Backend | Domain | GPU Required | Training Methods | Adapter Types |
|---------|--------|-------------|-----------------|---------------|
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |
## Enabling Fine-Tuning
Fine-tuning is disabled by default. Enable it with:
```bash
LOCALAI_ENABLE_FINETUNING=true local-ai
```
When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
## Quick Start
### 1. Start a fine-tuning job
```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
-H "Content-Type: application/json" \
-d '{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"backend": "trl",
"training_method": "sft",
"training_type": "lora",
"dataset_source": "yahma/alpaca-cleaned",
"num_epochs": 1,
"batch_size": 2,
"learning_rate": 0.0002,
"adapter_rank": 16,
"adapter_alpha": 16,
"extra_options": {
"max_seq_length": "512"
}
}'
```
### 2. Monitor progress (SSE stream)
```bash
curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress
```
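The progress endpoint emits standard SSE `data:` lines, each carrying one JSON-encoded progress event. As a rough sketch (not part of LocalAI itself), the stream can be consumed from Python like this; the exact event fields (e.g. `step`, `loss`) depend on the server's schema, and `requests` is a third-party dependency:

```python
import json

def parse_sse_line(line: str):
    """Parse one SSE line; return the decoded JSON payload, or None.

    Only `data:` lines carry progress events; comments (": ...") and
    blank keep-alive lines are ignored.
    """
    if not line.startswith("data:"):
        return None
    payload = line[len("data:"):].strip()
    if not payload:
        return None
    return json.loads(payload)

def stream_progress(base_url: str, job_id: str):
    """Yield progress events from the SSE endpoint (needs `requests`)."""
    import requests  # third-party: pip install requests
    url = f"{base_url}/api/fine-tuning/jobs/{job_id}/progress"
    with requests.get(url, stream=True, timeout=None) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            event = parse_sse_line(raw or "")
            if event is not None:
                yield event
```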
### 3. List checkpoints
```bash
curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints
```
### 4. Export model
```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \
-H "Content-Type: application/json" \
-d '{
"export_format": "gguf",
"quantization_method": "q4_k_m",
"output_path": "/models/my-finetuned-model"
}'
```
## API Reference
### Endpoints
| Method | Path | Description |
|--------|------|-------------|
| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job |
| `GET` | `/api/fine-tuning/jobs` | List all jobs |
| `GET` | `/api/fine-tuning/jobs/:id` | Get job details |
| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job |
| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream |
| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints |
| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model |
| `POST` | `/api/fine-tuning/datasets` | Upload dataset file |
### Job Request Fields
| Field | Type | Description |
|-------|------|-------------|
| `model` | string | HuggingFace model ID or local path (required) |
| `backend` | string | Backend name (default: `trl`) |
| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` |
| `training_type` | string | `lora` or `full` |
| `dataset_source` | string | HuggingFace dataset ID or local file path (required) |
| `adapter_rank` | int | LoRA rank (default: 16) |
| `adapter_alpha` | int | LoRA alpha (default: 16) |
| `num_epochs` | int | Number of training epochs (default: 3) |
| `batch_size` | int | Per-device batch size (default: 2) |
| `learning_rate` | float | Learning rate (default: 2e-4) |
| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) |
| `warmup_steps` | int | Warmup steps (default: 5) |
| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` |
| `extra_options` | map | Backend-specific options (see below) |
### Backend-Specific Options (`extra_options`)
#### TRL
| Key | Description | Default |
|-----|-------------|---------|
| `max_seq_length` | Maximum sequence length | `512` |
| `packing` | Enable sequence packing | `false` |
| `trust_remote_code` | Trust remote code in model | `false` |
| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` |
#### DPO-specific (training_method=dpo)
| Key | Description | Default |
|-----|-------------|---------|
| `beta` | KL penalty coefficient | `0.1` |
| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` |
| `max_length` | Maximum sequence length | `512` |
#### GRPO-specific (training_method=grpo)
| Key | Description | Default |
|-----|-------------|---------|
| `num_generations` | Number of generations per prompt | `4` |
| `max_completion_length` | Max completion token length | `256` |
### GRPO Reward Functions
GRPO training requires reward functions to evaluate model completions. Specify them via the `reward_functions` field (a typed array) or via `extra_options["reward_funcs"]` (a JSON string).
#### Built-in Reward Functions
| Name | Description | Parameters |
|------|-------------|-----------|
| `format_reward` | Checks for `<think>...</think>` followed by an answer (1.0/0.0) | — |
| `reasoning_accuracy_reward` | Extracts `<answer>` content, compares to dataset's `answer` column | — |
| `length_reward` | Score based on proximity to target length [0, 1] | `target_length` (default: 200) |
| `xml_tag_reward` | Scores properly opened/closed `<think>` and `<answer>` tags | — |
| `no_repetition_reward` | Penalizes n-gram repetition [0, 1] | — |
| `code_execution_reward` | Checks Python code block syntax validity (1.0/0.0) | — |
#### Inline Custom Reward Functions
You can provide custom reward function code as a Python function body. The function receives `completions` (list of strings) and `**kwargs`, and must return `list[float]`.
**Security restrictions for inline code:**
- Allowed builtins: `len`, `int`, `float`, `str`, `list`, `dict`, `range`, `enumerate`, `zip`, `map`, `filter`, `sorted`, `min`, `max`, `sum`, `abs`, `round`, `any`, `all`, `isinstance`, `print`, `True`, `False`, `None`
- Available modules: `re`, `math`, `json`, `string`
- Blocked: `open`, `__import__`, `exec`, `eval`, `compile`, `os`, `subprocess`, `getattr`, `setattr`, `delattr`, `globals`, `locals`
- Functions are compiled and validated at job start (fail-fast on syntax errors)
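To illustrate the contract, here is a hypothetical reward function that stays within the restrictions above (allowed builtins plus the `re` module). It is written as a full function so it can be tested locally; when submitting via the API, only the function *body* goes in the `code` field, as in the example request below:

```python
import re

# Rewards completions that contain a well-formed <answer>...</answer>
# block. Receives the list of completion strings plus **kwargs and
# returns one float per completion, per the documented interface.
def answer_block_reward(completions, **kwargs):
    scores = []
    for c in completions:
        match = re.search(r"<answer>.+?</answer>", c, re.DOTALL)
        scores.append(1.0 if match else 0.0)
    return scores
```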
#### Example API Request
```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-1.5B-Instruct",
"backend": "trl",
"training_method": "grpo",
"training_type": "lora",
"dataset_source": "my-reasoning-dataset",
"num_epochs": 1,
"batch_size": 2,
"learning_rate": 5e-6,
"reward_functions": [
{"type": "builtin", "name": "reasoning_accuracy_reward"},
{"type": "builtin", "name": "format_reward"},
{"type": "builtin", "name": "length_reward", "params": {"target_length": "200"}},
{"type": "inline", "name": "think_presence", "code": "return [1.0 if \"<think>\" in c else 0.0 for c in completions]"}
],
"extra_options": {
"num_generations": "4",
"max_completion_length": "256"
}
}'
```
### Export Formats
| Format | Description | Notes |
|--------|-------------|-------|
| `lora` | LoRA adapter files | Smallest, requires base model |
| `merged_16bit` | Full model in 16-bit | Large but standalone |
| `merged_4bit` | Full model in 4-bit | Smaller, standalone |
| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` |
### GGUF Quantization Methods
`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0`
## Web UI
When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides:
1. **Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters
2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets
3. **Training Monitor** — Real-time loss chart, progress bar, metrics display
4. **Export** — Export trained models in various formats
## Dataset Formats
Datasets should follow standard HuggingFace formats:
- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT
- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields)
- **GRPO**: Prompts with reward signals
Supported file formats: `.json`, `.jsonl`, `.csv`
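As a minimal sketch of an SFT dataset in the Alpaca format, the snippet below builds a `.jsonl` payload with the `instruction`, `input`, and `output` fields listed above (`input` may be empty); the validation helper is illustrative, not part of LocalAI:

```python
import json

# Two Alpaca-style SFT records: one JSON object per JSONL line.
records = [
    {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"},
    {"instruction": "What is 2 + 2?", "input": "", "output": "4"},
]

def to_jsonl(rows):
    """Serialize rows to JSONL, checking the required Alpaca fields."""
    required = {"instruction", "input", "output"}
    lines = []
    for row in rows:
        missing = required - row.keys()
        if missing:
            raise ValueError(f"missing fields: {sorted(missing)}")
        lines.append(json.dumps(row))
    return "\n".join(lines) + "\n"
```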
## Architecture
Fine-tuning uses the same gRPC backend architecture as inference:
1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel`
2. **Python backends**: Each backend implements the gRPC interface with its specific training framework
3. **Go service**: Manages job lifecycle, routes API requests to backends
4. **REST API**: HTTP endpoints with SSE progress streaming
5. **React UI**: Configuration form, real-time training monitor, export panel

go.mod

@@ -67,10 +67,18 @@ require (
)
require (
github.com/chasefleming/elem-go v0.30.0 // indirect
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/gofiber/template v1.8.3 // indirect
github.com/gofiber/template/html/v2 v2.1.3 // indirect
github.com/gofiber/utils v1.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/spf13/cobra v1.10.2 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/tmc/langchaingo v0.1.14 // indirect
)
@@ -136,8 +144,8 @@ require (
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/labstack/gommon v0.4.2 // indirect
github.com/mschoch/smat v0.2.0 // indirect
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 // indirect
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b // indirect
github.com/mudler/skillserver v0.0.5
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect

go.sum

@@ -148,6 +148,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Y
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chasefleming/elem-go v0.30.0 h1:BlhV1ekv1RbFiM8XZUQeln1Ikb4D+bu2eDO4agREvok=
github.com/chasefleming/elem-go v0.30.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f6vg71RUilJAA4=
github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM=
github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@@ -177,6 +179,7 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creachadair/mds v0.21.3 h1:RRgEAPIb52cU0q7UxGyN+13QlCVTZIL4slRr0cYYQfA=
github.com/creachadair/mds v0.21.3/go.mod h1:1ltMWZd9yXhaHEoZwBialMaviWVUpRPvMwVP7saFAzM=
github.com/creachadair/otp v0.5.0 h1:q3Th7CXm2zlmCdBjw5tEPFOj4oWJMnVL5HXlq0sNKS0=
@@ -185,6 +188,8 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/cyphar/filepath-securejoin v0.5.1 h1:eYgfMq5yryL4fbWfkLpFFy2ukSELzaJOTaUTuh+oF48=
github.com/cyphar/filepath-securejoin v0.5.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
@@ -341,6 +346,12 @@ github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/gofiber/template v1.8.3 h1:hzHdvMwMo/T2kouz2pPCA0zGiLCeMnoGsQZBTSYgZxc=
github.com/gofiber/template v1.8.3/go.mod h1:bs/2n0pSNPOkRa5VJ8zTIvedcI/lEYxzV3+YPXdBvq8=
github.com/gofiber/template/html/v2 v2.1.3 h1:n1LYBtmr9C0V/k/3qBblXyMxV5B0o/gpb6dFLp8ea+o=
github.com/gofiber/template/html/v2 v2.1.3/go.mod h1:U5Fxgc5KpyujU9OqKzy6Kn6Qup6Tm7zdsISR+VpnHRE=
github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM=
github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0=
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@@ -445,6 +456,8 @@ github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc=
github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/ipfs/boxo v0.30.0 h1:7afsoxPGGqfoH7Dum/wOTGUB9M5fb8HyKPMlLfBvIEQ=
github.com/ipfs/boxo v0.30.0/go.mod h1:BPqgGGyHB9rZZcPSzah2Dc9C+5Or3U1aQe7EH1H7370=
github.com/ipfs/go-block-format v0.2.0 h1:ZqrkxBA2ICbDRbK8KJs/u0O3dlp6gmAuuXUJNiW1Ycs=
@@ -666,6 +679,8 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM=
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a h1:combrnE/eLPnUhqrYmtFmqEfR6x9xS+HoTFdnMozvik=
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a/go.mod h1:AbBcAE9JqkexN4aG8rYQn5LzmzffWqcMvQ+Nlvin3WI=
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4 h1:zWrAdAI/gwAPwXQAJuFLF8vvJdsxpxjKiBiC0EzhLOo=
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4/go.mod h1:g+6CD5tP4a+rRW20CrMpE/JDazq5N4n4YDxIT7tT1mY=
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b h1:A74T2Lauvg61KodYqsjTYDY05kPLcW+efVZjd23dghU=
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
github.com/mudler/edgevpn v0.31.1 h1:7qegiDWd0kAg6ljhNHxqvp8hbo/6BbzSdbb7/2WZfiY=
@@ -676,6 +691,10 @@ github.com/mudler/go-processmanager v0.1.0 h1:fcSKgF9U/a1Z7KofAFeZnke5YseadCI5Gq
github.com/mudler/go-processmanager v0.1.0/go.mod h1:h6kmHUZeafr+k5hRYpGLMzJFH4hItHffgpRo2QIkP+o=
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 h1:KVTEukvLlQXKZx1C1ZLru+ahaiECLF+7v2caK8vauJ0=
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45 h1:+zTrbYk70wHrtvpsO2k7gMPvHYnWYCnXNxAtMex+7yg=
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b h1:XeAnOEOOSKMfS5XNGpRTltQgjKCinho0V4uAhrgxN7Q=
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b/go.mod h1:xuPtgL9zUyiQLmspYzO3kaboYrGbWmwi8BQPt1aCAcs=
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2 h1:+WHsL/j6EWOMUiMVIOJNKOwSKiQt/qDPc9fePCf87fA=
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
github.com/mudler/skillserver v0.0.5 h1:t6HPpeSX8kEP7B8F5GXoQUam5VEYNmJuG6oy2/vdTu8=
@@ -855,6 +874,7 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
@@ -928,6 +948,10 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=


@@ -63,4 +63,11 @@ type Backend interface {
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)
// Fine-tuning
StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error)
FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error
StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error)
ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error)
}


@@ -120,6 +120,26 @@ func (llm *Base) AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, err
return nil, fmt.Errorf("unimplemented")
}
func (llm *Base) StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
return nil, fmt.Errorf("unimplemented")
}
func (llm *Base) FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) StopFineTune(*pb.FineTuneStopRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
return nil, fmt.Errorf("unimplemented")
}
func (llm *Base) ExportModel(*pb.ExportModelRequest) error {
return fmt.Errorf("unimplemented")
}
func memoryUsage() *pb.MemoryUsageData {
mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64),


@@ -632,6 +632,142 @@ func (c *Client) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opt
return client.AudioDecode(ctx, in, opts...)
}
func (c *Client) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StartFineTune(ctx, in, opts...)
}
func (c *Client) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
stream, err := client.FineTuneProgress(ctx, in, opts...)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
update, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
f(update)
}
return nil
}
func (c *Client) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.StopFineTune(ctx, in, opts...)
}
func (c *Client) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.ListCheckpoints(ctx, in, opts...)
}
func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.ExportModel(ctx, in, opts...)
}
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
if !c.parallel {
c.opMutex.Lock()


@@ -123,6 +123,68 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
return e.s.GetMetrics(ctx, in)
}
func (e *embedBackend) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
return e.s.StartFineTune(ctx, in)
}
func (e *embedBackend) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
bs := &embedBackendFineTuneProgressStream{
ctx: ctx,
fn: f,
}
return e.s.FineTuneProgress(in, bs)
}
func (e *embedBackend) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
return e.s.StopFineTune(ctx, in)
}
func (e *embedBackend) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
return e.s.ListCheckpoints(ctx, in)
}
func (e *embedBackend) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
return e.s.ExportModel(ctx, in)
}
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
type embedBackendFineTuneProgressStream struct {
ctx context.Context
fn func(update *pb.FineTuneProgressUpdate)
}
func (e *embedBackendFineTuneProgressStream) Send(update *pb.FineTuneProgressUpdate) error {
e.fn(update)
return nil
}
func (e *embedBackendFineTuneProgressStream) SetHeader(md metadata.MD) error {
return nil
}
func (e *embedBackendFineTuneProgressStream) SendHeader(md metadata.MD) error {
return nil
}
func (e *embedBackendFineTuneProgressStream) SetTrailer(md metadata.MD) {
}
func (e *embedBackendFineTuneProgressStream) Context() context.Context {
return e.ctx
}
func (e *embedBackendFineTuneProgressStream) SendMsg(m any) error {
if x, ok := m.(*pb.FineTuneProgressUpdate); ok {
return e.Send(x)
}
return nil
}
func (e *embedBackendFineTuneProgressStream) RecvMsg(m any) error {
return nil
}
type embedBackendServerStream struct {
ctx context.Context
fn func(reply *pb.Reply)


@@ -35,6 +35,13 @@ type AIModel interface {
AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)
ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)
// Fine-tuning
StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error)
FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error
StopFineTune(*pb.FineTuneStopRequest) error
ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error)
ExportModel(*pb.ExportModelRequest) error
}
func newReply(s string) *pb.Reply {


@@ -308,6 +308,75 @@ func (s *server) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*p
return res, nil
}
func (s *server) StartFineTune(ctx context.Context, in *pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.StartFineTune(in)
if err != nil {
return &pb.FineTuneJobResult{Success: false, Message: fmt.Sprintf("Error starting fine-tune: %s", err.Error())}, err
}
return res, nil
}
func (s *server) FineTuneProgress(in *pb.FineTuneProgressRequest, stream pb.Backend_FineTuneProgressServer) error {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
updateChan := make(chan *pb.FineTuneProgressUpdate)
done := make(chan bool)
go func() {
for update := range updateChan {
stream.Send(update)
}
done <- true
}()
err := s.llm.FineTuneProgress(in, updateChan)
<-done
return err
}
func (s *server) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.StopFineTune(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error stopping fine-tune: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Fine-tune stopped", Success: true}, nil
}
func (s *server) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.ListCheckpoints(in)
if err != nil {
return nil, err
}
return res, nil
}
func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*pb.Result, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
err := s.llm.ExportModel(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error exporting model: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Model exported", Success: true}, nil
}
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
if s.llm.Locking() {
s.llm.Lock()