From d9c1db2b874d12a60f489a3504393418b0b764cb Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 21 Mar 2026 02:08:02 +0100 Subject: [PATCH] feat: add (experimental) fine-tuning support with TRL (#9088) * feat: add fine-tuning endpoint Signed-off-by: Ettore Di Giacinto * feat(experimental): add fine-tuning endpoint and TRL support This changeset defines new GRPC signatues for Fine tuning backends, and add TRL backend as initial fine-tuning engine. This implementation also supports exporting to GGUF and automatically importing it to LocalAI after fine-tuning. Signed-off-by: Ettore Di Giacinto * commit TRL backend, stop by killing process Signed-off-by: Ettore Di Giacinto * move fine-tune to generic features Signed-off-by: Ettore Di Giacinto * add evals, reorder menu Signed-off-by: Ettore Di Giacinto * Fix tests Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .agents/debugging-backends.md | 141 ++ .github/workflows/backend.yml | 39 + AGENTS.md | 1 + Makefile | 8 +- backend/backend.proto | 109 ++ backend/index.yaml | 51 + backend/python/trl/Makefile | 26 + backend/python/trl/backend.py | 860 ++++++++++ backend/python/trl/install.sh | 37 + backend/python/trl/requirements-cpu.txt | 9 + backend/python/trl/requirements-cublas12.txt | 9 + backend/python/trl/requirements-cublas13.txt | 9 + backend/python/trl/requirements.txt | 3 + backend/python/trl/reward_functions.py | 236 +++ backend/python/trl/run.sh | 10 + backend/python/trl/test.py | 58 + backend/python/trl/test.sh | 11 + core/cli/run.go | 8 + core/config/application_config.go | 14 + core/gallery/importers/local.go | 205 +++ core/gallery/importers/local_test.go | 148 ++ core/http/app.go | 11 + core/http/auth/features.go | 19 + core/http/auth/permissions.go | 8 +- .../endpoints/localai/agent_collections.go | 5 +- core/http/endpoints/localai/finetune.go | 362 ++++ core/http/react-ui/src/App.css | 75 +- .../src/components/LoadingSpinner.jsx | 16 +- 
core/http/react-ui/src/components/Sidebar.jsx | 208 ++- core/http/react-ui/src/pages/FineTune.jsx | 1525 +++++++++++++++++ core/http/react-ui/src/pages/Studio.jsx | 48 + core/http/react-ui/src/pages/Users.jsx | 35 +- core/http/react-ui/src/router.jsx | 5 + core/http/react-ui/src/utils/api.js | 19 + core/http/routes/auth.go | 7 +- core/http/routes/finetuning.go | 42 + core/http/routes/localai.go | 5 +- core/schema/finetune.go | 111 ++ core/services/agent_pool.go | 6 +- core/services/finetune.go | 700 ++++++++ docs/content/features/fine-tuning.md | 226 +++ go.mod | 12 +- go.sum | 24 + pkg/grpc/backend.go | 7 + pkg/grpc/base/base.go | 20 + pkg/grpc/client.go | 136 ++ pkg/grpc/embed.go | 62 + pkg/grpc/interface.go | 7 + pkg/grpc/server.go | 69 + 49 files changed, 5652 insertions(+), 110 deletions(-) create mode 100644 .agents/debugging-backends.md create mode 100644 backend/python/trl/Makefile create mode 100644 backend/python/trl/backend.py create mode 100644 backend/python/trl/install.sh create mode 100644 backend/python/trl/requirements-cpu.txt create mode 100644 backend/python/trl/requirements-cublas12.txt create mode 100644 backend/python/trl/requirements-cublas13.txt create mode 100644 backend/python/trl/requirements.txt create mode 100644 backend/python/trl/reward_functions.py create mode 100644 backend/python/trl/run.sh create mode 100644 backend/python/trl/test.py create mode 100644 backend/python/trl/test.sh create mode 100644 core/gallery/importers/local.go create mode 100644 core/gallery/importers/local_test.go create mode 100644 core/http/endpoints/localai/finetune.go create mode 100644 core/http/react-ui/src/pages/FineTune.jsx create mode 100644 core/http/react-ui/src/pages/Studio.jsx create mode 100644 core/http/routes/finetuning.go create mode 100644 core/schema/finetune.go create mode 100644 core/services/finetune.go create mode 100644 docs/content/features/fine-tuning.md diff --git a/.agents/debugging-backends.md b/.agents/debugging-backends.md new 
file mode 100644 index 000000000..e818753c2 --- /dev/null +++ b/.agents/debugging-backends.md @@ -0,0 +1,141 @@ +# Debugging and Rebuilding Backends + +When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild. + +## Architecture Overview + +- **Source directory**: `backend/python//` (or `backend/go//`, `backend/cpp//`) +- **Installed directory**: `backends//` — this is what LocalAI actually runs. It is populated by `make backends/` which builds a Docker image, exports it, and installs it via `local-ai backends install`. +- **Virtual environment**: `backends//venv/` — the installed Python venv (for Python backends). The Python binary is at `backends//venv/bin/python`. + +Editing files in `backend/python//` does **not** affect the running backend until you rebuild with `make backends/`. + +## Diagnosing Failures + +### 1. Check the logs + +Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID: + +``` +GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..." +``` + +Common error patterns: +- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub. +- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`). +- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback. + +### 2. 
Test the Python environment directly + +You can run the installed venv's Python to check imports without starting the full server: + +```bash +backends//venv/bin/python -c "import datasets; print(datasets.__version__)" +``` + +If `pip` is missing from the venv, bootstrap it: + +```bash +backends//venv/bin/python -m ensurepip +``` + +Then use `backends//venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements. + +### 3. Check upstream dependency constraints + +When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`: + +``` +https://github.com/huggingface/trl/blob/main/requirements.txt +``` + +Pin minimum versions in the backend's requirements files to match upstream. + +## Common Fixes + +### Missing gRPC methods + +If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`: + +```python +def LoadModel(self, request, context): + """No-op — actual loading happens elsewhere.""" + return backend_pb2.Result(success=True, message="OK") +``` + +The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading. + +### Dependency version conflicts + +Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps: + +1. Identify the broken import in the logs +2. Test in the installed venv: `backends//venv/bin/python -c "import "` +3. Check upstream requirements for version constraints +4. Update **all** requirements files in `backend/python//`: + - `requirements.txt` — base deps (grpcio, protobuf) + - `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index) + - `requirements-cublas12.txt` — CUDA 12 + - `requirements-cublas13.txt` — CUDA 13 +5. 
Rebuild: `make backends/` + +### PyTorch index conflicts (uv resolver) + +The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`: + +```bash +EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" +installRequirements +``` + +Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference. + +## Rebuilding + +### Rebuild a single backend + +```bash +make backends/ +``` + +This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/.tar`, and installs it into `backends//`. It also rebuilds the `local-ai` Go binary (without extra tags). + +**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/` step will overwrite your binary without that tag. Rebuild the Go binary afterward: + +```bash +GO_TAGS=auth make build +``` + +### Rebuild and restart + +After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded. + +```bash +# Kill existing process +kill + +# Restart +./local-ai run --debug [your flags] +``` + +### Quick iteration (skip Docker rebuild) + +For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly: + +```bash +# Edit the installed copy +vim backends//backend.py + +# Restart LocalAI to respawn the gRPC process +``` + +This is useful for testing but **does not persist** — the next `make backends/` will overwrite it. Always commit fixes to the source in `backend/python//`. + +## Verification + +After fixing and rebuilding: + +1. 
Start LocalAI and confirm the backend registers: look for `Registering backend name=""` in the logs +2. Trigger the operation that failed (e.g. start a fine-tuning job) +3. Watch the GRPC stderr/stdout lines for the backend's model ID +4. Confirm no errors in the traceback diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 50cc9b180..6842d7da1 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -118,6 +118,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-cpu-trl' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "trl" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" @@ -366,6 +379,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "8" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-trl' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "trl" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" @@ -757,6 +783,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-trl' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "trl" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" 
diff --git a/AGENTS.md b/AGENTS.md index 41a89eab6..27785bf82 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -12,6 +12,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read | [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing | | [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI | | [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control | +| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends | ## Quick Reference diff --git a/Makefile b/Makefile index c429097b6..e3c02e9c2 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts 
backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl GOCMD=go GOTEST=$(GOCMD) test @@ -421,6 +421,7 @@ prepare-test-extra: protogen-python $(MAKE) -C backend/python/voxcpm $(MAKE) -C backend/python/whisperx $(MAKE) -C backend/python/ace-step + $(MAKE) -C backend/python/trl test-extra: prepare-test-extra $(MAKE) -C backend/python/transformers test @@ -440,6 +441,7 @@ test-extra: prepare-test-extra $(MAKE) -C backend/python/voxcpm test $(MAKE) -C backend/python/whisperx test $(MAKE) -C backend/python/ace-step test + $(MAKE) -C backend/python/trl test DOCKER_IMAGE?=local-ai IMAGE_TYPE?=core @@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true +BACKEND_TRL = trl|python|.|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) @@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) $(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) +$(eval $(call generate-docker-build-target,$(BACKEND_TRL))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o 
backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl ######################################################## ### Mock Backend for E2E Tests diff --git a/backend/backend.proto b/backend/backend.proto index d8a0cd9fc..91497c523 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -39,6 +39,13 @@ service Backend { rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {} rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {} + + // Fine-tuning RPCs + rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {} + rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {} + rpc StopFineTune(FineTuneStopRequest) returns (Result) {} + rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {} + rpc 
ExportModel(ExportModelRequest) returns (Result) {} } // Define the empty request @@ -528,3 +535,105 @@ message ModelMetadataResponse { string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable) ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis } + +// Fine-tuning messages + +message FineTuneRequest { + // Model identification + string model = 1; // HF model name or local path + string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train + string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training" + + // Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion) + int32 adapter_rank = 10; // LoRA rank (r), default 16 + int32 adapter_alpha = 11; // scaling factor, default 16 + float adapter_dropout = 12; // default 0.0 + repeated string target_modules = 13; // layer names to adapt + + // Universal training hyperparameters + float learning_rate = 20; // default 2e-4 + int32 num_epochs = 21; // default 3 + int32 batch_size = 22; // default 2 + int32 gradient_accumulation_steps = 23; // default 4 + int32 warmup_steps = 24; // default 5 + int32 max_steps = 25; // 0 = use epochs + int32 save_steps = 26; // 0 = only save final + float weight_decay = 27; // default 0.01 + bool gradient_checkpointing = 28; + string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy + int32 seed = 30; // default 3407 + string mixed_precision = 31; // fp16, bf16, fp8, no + + // Dataset + string dataset_source = 40; // HF dataset ID, local file/dir path + string dataset_split = 41; // train, test, etc. 
+ + // Output + string output_dir = 50; + string job_id = 51; // client-assigned or auto-generated + + // Resume training from a checkpoint + string resume_from_checkpoint = 55; // path to checkpoint dir to resume from + + // Backend-specific AND method-specific extensibility + map extra_options = 60; +} + +message FineTuneJobResult { + string job_id = 1; + bool success = 2; + string message = 3; +} + +message FineTuneProgressRequest { + string job_id = 1; +} + +message FineTuneProgressUpdate { + string job_id = 1; + int32 current_step = 2; + int32 total_steps = 3; + float current_epoch = 4; + float total_epochs = 5; + float loss = 6; + float learning_rate = 7; + float grad_norm = 8; + float eval_loss = 9; + float eta_seconds = 10; + float progress_percent = 11; + string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped + string message = 13; + string checkpoint_path = 14; // set when a checkpoint is saved + string sample_path = 15; // set when a sample is generated (video/image backends) + map extra_metrics = 16; // method-specific metrics +} + +message FineTuneStopRequest { + string job_id = 1; + bool save_checkpoint = 2; +} + +message ListCheckpointsRequest { + string output_dir = 1; +} + +message ListCheckpointsResponse { + repeated CheckpointInfo checkpoints = 1; +} + +message CheckpointInfo { + string path = 1; + int32 step = 2; + float epoch = 3; + float loss = 4; + string created_at = 5; +} + +message ExportModelRequest { + string checkpoint_path = 1; + string output_path = 2; + string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers + string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc. 
+ string model = 5; // base model name (for merge operations) + map extra_options = 6; +} diff --git a/backend/index.yaml b/backend/index.yaml index 8dde88256..8a8232747 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -3030,3 +3030,54 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral" mirrors: - localai/localai-backends:master-metal-darwin-arm64-voxtral +- &trl + name: "trl" + alias: "trl" + license: apache-2.0 + description: | + HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods. + Works on CPU and GPU. + urls: + - https://github.com/huggingface/trl + tags: + - fine-tuning + - LLM + - CPU + - GPU + - CUDA + capabilities: + default: "cpu-trl" + nvidia: "cuda12-trl" + nvidia-cuda-12: "cuda12-trl" + nvidia-cuda-13: "cuda13-trl" +## TRL backend images +- !!merge <<: *trl + name: "cpu-trl" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl" + mirrors: + - localai/localai-backends:latest-cpu-trl +- !!merge <<: *trl + name: "cpu-trl-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl" + mirrors: + - localai/localai-backends:master-cpu-trl +- !!merge <<: *trl + name: "cuda12-trl" + uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl" + mirrors: + - localai/localai-backends:latest-cublas-cuda12-trl +- !!merge <<: *trl + name: "cuda12-trl-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl" + mirrors: + - localai/localai-backends:master-cublas-cuda12-trl +- !!merge <<: *trl + name: "cuda13-trl" + uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl" + mirrors: + - localai/localai-backends:latest-cublas-cuda13-trl +- !!merge <<: *trl + name: "cuda13-trl-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl" + mirrors: + - localai/localai-backends:master-cublas-cuda13-trl diff --git a/backend/python/trl/Makefile b/backend/python/trl/Makefile new file mode 100644 index 
000000000..ababb961c --- /dev/null +++ b/backend/python/trl/Makefile @@ -0,0 +1,26 @@ +# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export) +LLAMA_CPP_CONVERT_VERSION ?= master + +.PHONY: trl +trl: + LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh + +.PHONY: run +run: trl + @echo "Running trl..." + bash run.sh + @echo "trl run." + +.PHONY: test +test: trl + @echo "Testing trl..." + bash test.sh + @echo "trl tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/trl/backend.py b/backend/python/trl/backend.py new file mode 100644 index 000000000..c414e6fb6 --- /dev/null +++ b/backend/python/trl/backend.py @@ -0,0 +1,860 @@ +#!/usr/bin/env python3 +""" +TRL fine-tuning backend for LocalAI. + +Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO) +using standard HuggingFace transformers + PEFT. Works on both CPU and GPU. 
+""" +import argparse +import json +import os +import queue +import signal +import sys +import threading +import time +import uuid +from concurrent import futures + +import grpc +import backend_pb2 +import backend_pb2_grpc + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4')) + + +class ProgressCallback: + """HuggingFace TrainerCallback that pushes progress updates to a queue.""" + + def __init__(self, job_id, progress_queue, total_epochs): + self.job_id = job_id + self.progress_queue = progress_queue + self.total_epochs = total_epochs + + def get_callback(self): + from transformers import TrainerCallback + + parent = self + + class _Callback(TrainerCallback): + def __init__(self): + self._train_start_time = None + + def on_train_begin(self, args, state, control, **kwargs): + self._train_start_time = time.time() + + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is None: + return + total_steps = state.max_steps if state.max_steps > 0 else 0 + progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0 + eta = 0.0 + if state.global_step > 0 and total_steps > 0 and self._train_start_time: + elapsed = time.time() - self._train_start_time + remaining_steps = total_steps - state.global_step + if state.global_step > 0: + eta = remaining_steps * (elapsed / state.global_step) + + extra_metrics = {} + for k, v in logs.items(): + if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'): + extra_metrics[k] = float(v) + + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=total_steps, + current_epoch=float(logs.get('epoch', 0)), + total_epochs=float(parent.total_epochs), + loss=float(logs.get('loss', 0)), + learning_rate=float(logs.get('learning_rate', 0)), + grad_norm=float(logs.get('grad_norm', 0)), + eval_loss=float(logs.get('eval_loss', 0)), + eta_seconds=float(eta), + 
progress_percent=float(progress), + status="training", + extra_metrics=extra_metrics, + ) + parent.progress_queue.put(update) + + def on_prediction_step(self, args, state, control, **kwargs): + """Send periodic updates during evaluation so the UI doesn't freeze.""" + if not hasattr(self, '_eval_update_counter'): + self._eval_update_counter = 0 + self._eval_update_counter += 1 + # Throttle: send an update every 10 prediction steps + if self._eval_update_counter % 10 != 0: + return + total_steps = state.max_steps if state.max_steps > 0 else 0 + progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0 + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=total_steps, + current_epoch=float(state.epoch or 0), + total_epochs=float(parent.total_epochs), + progress_percent=float(progress), + status="training", + message=f"Evaluating... (batch {self._eval_update_counter})", + ) + parent.progress_queue.put(update) + + def on_evaluate(self, args, state, control, metrics=None, **kwargs): + """Report eval results once evaluation is done.""" + # Reset prediction counter for next eval round + self._eval_update_counter = 0 + + total_steps = state.max_steps if state.max_steps > 0 else 0 + progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0 + + eval_loss = 0.0 + extra_metrics = {} + if metrics: + eval_loss = float(metrics.get('eval_loss', 0)) + for k, v in metrics.items(): + if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'): + extra_metrics[k] = float(v) + + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=total_steps, + current_epoch=float(state.epoch or 0), + total_epochs=float(parent.total_epochs), + eval_loss=eval_loss, + progress_percent=float(progress), + status="training", + message=f"Evaluation complete at step {state.global_step}", + extra_metrics=extra_metrics, + ) + 
parent.progress_queue.put(update) + + def on_save(self, args, state, control, **kwargs): + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + status="saving", + message=f"Checkpoint saved at step {state.global_step}", + checkpoint_path=checkpoint_path, + ) + parent.progress_queue.put(update) + + def on_train_end(self, args, state, control, **kwargs): + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=state.max_steps, + progress_percent=100.0, + status="completed", + message="Training completed", + ) + parent.progress_queue.put(update) + + return _Callback() + + +class ActiveJob: + """Represents an active fine-tuning job.""" + + def __init__(self, job_id): + self.job_id = job_id + self.progress_queue = queue.Queue() + self.trainer = None + self.thread = None + self.model = None + self.tokenizer = None + self.error = None + self.completed = False + self.stopped = False + + +def _is_gated_repo_error(exc): + """Check if an exception is caused by a gated HuggingFace repo requiring authentication.""" + try: + from huggingface_hub.utils import GatedRepoError + if isinstance(exc, GatedRepoError): + return True + except ImportError: + pass + msg = str(exc).lower() + if "gated repo" in msg or "access to model" in msg: + return True + if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'): + if exc.response.status_code in (401, 403): + return True + return False + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + def __init__(self): + self.active_job = None + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + """Accept LoadModel — actual model loading happens in StartFineTune.""" + return backend_pb2.Result(success=True, message="OK") + + def 
StartFineTune(self, request, context): + if self.active_job is not None and not self.active_job.completed: + return backend_pb2.FineTuneJobResult( + job_id="", + success=False, + message="A fine-tuning job is already running", + ) + + job_id = request.job_id if request.job_id else str(uuid.uuid4()) + job = ActiveJob(job_id) + self.active_job = job + + # Start training in background thread + thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True) + job.thread = thread + thread.start() + + return backend_pb2.FineTuneJobResult( + job_id=job_id, + success=True, + message="Fine-tuning job started", + ) + + def _run_training(self, request, job): + try: + self._do_training(request, job) + except Exception as e: + if _is_gated_repo_error(e): + msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. " + "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.") + else: + msg = f"Training failed: {e}" + job.error = msg + job.completed = True + update = backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, + status="failed", + message=msg, + ) + job.progress_queue.put(update) + # Send sentinel + job.progress_queue.put(None) + + def _do_training(self, request, job): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + from datasets import load_dataset, Dataset + + extra = dict(request.extra_options) + training_method = request.training_method or "sft" + training_type = request.training_type or "lora" + + # Send loading status + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}", + )) + + # Determine device and dtype + device_map = "auto" if torch.cuda.is_available() else "cpu" + dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16 + + # HuggingFace token for gated repos (from extra_options or HF_TOKEN env) + hf_token = extra.get("hf_token") or 
os.environ.get("HF_TOKEN") + + # Load model + model_kwargs = {"device_map": device_map, "torch_dtype": dtype} + if hf_token: + model_kwargs["token"] = hf_token + if extra.get("trust_remote_code", "false").lower() == "true": + model_kwargs["trust_remote_code"] = True + if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available(): + from transformers import BitsAndBytesConfig + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + + model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + job.model = model + job.tokenizer = tokenizer + + # Apply LoRA if requested + if training_type == "lora": + from peft import LoraConfig, get_peft_model + lora_r = request.adapter_rank if request.adapter_rank > 0 else 16 + lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16 + lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0 + + target_modules = list(request.target_modules) if request.target_modules else None + peft_config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + target_modules=target_modules or "all-linear", + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, peft_config) + + # Load dataset + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="loading_dataset", message="Loading dataset", + )) + + dataset_split = request.dataset_split or "train" + if os.path.exists(request.dataset_source): + if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'): + dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split) + elif request.dataset_source.endswith('.csv'): + dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split) + else: + 
dataset = load_dataset(request.dataset_source, split=dataset_split) + else: + dataset = load_dataset(request.dataset_source, split=dataset_split) + + # Eval dataset setup + eval_dataset = None + eval_strategy = extra.get("eval_strategy", "steps") + eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500))) + + if eval_strategy != "no": + eval_split = extra.get("eval_split") + eval_dataset_source = extra.get("eval_dataset_source") + if eval_split: + # Load a specific split as eval dataset + if os.path.exists(request.dataset_source): + if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'): + eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split) + elif request.dataset_source.endswith('.csv'): + eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split) + else: + eval_dataset = load_dataset(request.dataset_source, split=eval_split) + else: + eval_dataset = load_dataset(request.dataset_source, split=eval_split) + elif eval_dataset_source: + # Load eval dataset from a separate source + eval_dataset = load_dataset(eval_dataset_source, split="train") + else: + # Auto-split from training set + eval_split_ratio = float(extra.get("eval_split_ratio", "0.1")) + split = dataset.train_test_split(test_size=eval_split_ratio) + dataset = split["train"] + eval_dataset = split["test"] + + if eval_strategy == "no": + eval_dataset = None + + # Training config + output_dir = request.output_dir or f"./output-{job.job_id}" + num_epochs = request.num_epochs if request.num_epochs > 0 else 3 + batch_size = request.batch_size if request.batch_size > 0 else 2 + lr = request.learning_rate if request.learning_rate > 0 else 2e-4 + grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4 + warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5 + weight_decay = request.weight_decay if 
request.weight_decay > 0 else 0.01 + max_steps = request.max_steps if request.max_steps > 0 else -1 + save_steps = request.save_steps if request.save_steps > 0 else 500 + seed = request.seed if request.seed > 0 else 3407 + optimizer = request.optimizer or "adamw_torch" + + # Checkpoint save controls + save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited + save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no + + # CPU vs GPU training args (can be overridden via extra_options) + use_cpu = not torch.cuda.is_available() + common_train_kwargs = {} + if use_cpu: + common_train_kwargs["use_cpu"] = True + common_train_kwargs["fp16"] = False + common_train_kwargs["bf16"] = False + common_train_kwargs["gradient_checkpointing"] = False + else: + common_train_kwargs["bf16"] = True + common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing + + # Allow extra_options to override training kwargs + for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"): + if flag in extra: + common_train_kwargs[flag] = extra[flag].lower() == "true" + + # Create progress callback + progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs) + + # Build save kwargs (shared across all methods) + _save_kwargs = {} + if save_strategy == "steps" and save_steps > 0: + _save_kwargs["save_steps"] = save_steps + _save_kwargs["save_strategy"] = "steps" + elif save_strategy == "epoch": + _save_kwargs["save_strategy"] = "epoch" + elif save_strategy == "no": + _save_kwargs["save_strategy"] = "no" + else: + _save_kwargs["save_steps"] = save_steps + _save_kwargs["save_strategy"] = "steps" + if save_total_limit: + _save_kwargs["save_total_limit"] = save_total_limit + + # Eval kwargs + _eval_kwargs = {} + if eval_dataset is not None: + _eval_kwargs["eval_strategy"] = eval_strategy + _eval_kwargs["eval_steps"] = eval_steps + + # Common training arguments shared by all methods + _common_args = dict( + output_dir=output_dir, 
+ num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + learning_rate=lr, + gradient_accumulation_steps=grad_accum, + warmup_steps=warmup_steps, + weight_decay=weight_decay, + max_steps=max_steps, + seed=seed, + optim=optimizer, + logging_steps=1, + report_to="none", + **_save_kwargs, + **common_train_kwargs, + **_eval_kwargs, + ) + + # Select trainer based on training method + if training_method == "sft": + from trl import SFTTrainer, SFTConfig + + max_length = int(extra.get("max_seq_length", "512")) + packing = extra.get("packing", "false").lower() == "true" + + training_args = SFTConfig( + max_length=max_length, + packing=packing, + **_common_args, + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "dpo": + from trl import DPOTrainer, DPOConfig + + beta = float(extra.get("beta", "0.1")) + loss_type = extra.get("loss_type", "sigmoid") + max_length = int(extra.get("max_length", "512")) + + training_args = DPOConfig( + beta=beta, + loss_type=loss_type, + max_length=max_length, + **_common_args, + ) + + trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "grpo": + from trl import GRPOTrainer, GRPOConfig + + num_generations = int(extra.get("num_generations", "4")) + max_completion_length = int(extra.get("max_completion_length", "256")) + + training_args = GRPOConfig( + num_generations=num_generations, + max_completion_length=max_completion_length, + **_common_args, + ) + + # GRPO requires reward functions passed via extra_options as a JSON list + from reward_functions import build_reward_functions + + reward_funcs = [] + if extra.get("reward_funcs"): + reward_funcs = build_reward_functions(extra["reward_funcs"]) + + 
if not reward_funcs: + raise ValueError( + "GRPO requires at least one reward function. " + "Specify reward_functions in the request or " + "reward_funcs in extra_options." + ) + + trainer = GRPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + reward_funcs=reward_funcs, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "orpo": + from trl import ORPOTrainer, ORPOConfig + + beta = float(extra.get("beta", "0.1")) + max_length = int(extra.get("max_length", "512")) + + training_args = ORPOConfig( + beta=beta, + max_length=max_length, + **_common_args, + ) + + trainer = ORPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "kto": + from trl import KTOTrainer, KTOConfig + + beta = float(extra.get("beta", "0.1")) + max_length = int(extra.get("max_length", "512")) + + training_args = KTOConfig( + beta=beta, + max_length=max_length, + **_common_args, + ) + + trainer = KTOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "rloo": + from trl import RLOOTrainer, RLOOConfig + + num_generations = int(extra.get("num_generations", "4")) + max_completion_length = int(extra.get("max_completion_length", "256")) + + training_args = RLOOConfig( + num_generations=num_generations, + max_new_tokens=max_completion_length, + **_common_args, + ) + + trainer = RLOOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "reward": + from trl import RewardTrainer, RewardConfig + + max_length = int(extra.get("max_length", "512")) + + training_args = RewardConfig( + max_length=max_length, + 
**_common_args, + ) + + trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + else: + raise ValueError(f"Unsupported training method: {training_method}. " + "Supported: sft, dpo, grpo, orpo, kto, rloo, reward") + + job.trainer = trainer + + # Start training + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="training", message="Training started", + )) + + resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None + trainer.train(resume_from_checkpoint=resume_ckpt) + + # Save final model + trainer.save_model(output_dir) + if tokenizer: + tokenizer.save_pretrained(output_dir) + + job.completed = True + # Sentinel to signal stream end + job.progress_queue.put(None) + + def FineTuneProgress(self, request, context): + if self.active_job is None or self.active_job.job_id != request.job_id: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Job {request.job_id} not found") + return + + job = self.active_job + while True: + try: + update = job.progress_queue.get(timeout=1.0) + if update is None: + break + yield update + if update.status in ("completed", "failed", "stopped"): + break + except queue.Empty: + if job.completed or job.stopped: + break + if not context.is_active(): + break + continue + + def StopFineTune(self, request, context): + # Stopping is handled by killing the process from Go via ShutdownModel. 
+ return backend_pb2.Result(success=True, message="OK") + + def ListCheckpoints(self, request, context): + output_dir = request.output_dir + if not os.path.isdir(output_dir): + return backend_pb2.ListCheckpointsResponse(checkpoints=[]) + + checkpoints = [] + for entry in sorted(os.listdir(output_dir)): + if entry.startswith("checkpoint-"): + ckpt_path = os.path.join(output_dir, entry) + if not os.path.isdir(ckpt_path): + continue + step = 0 + try: + step = int(entry.split("-")[1]) + except (IndexError, ValueError): + pass + + # Try to read trainer_state.json for metadata + loss = 0.0 + epoch = 0.0 + state_file = os.path.join(ckpt_path, "trainer_state.json") + if os.path.exists(state_file): + try: + with open(state_file) as f: + state = json.load(f) + if state.get("log_history"): + last_log = state["log_history"][-1] + loss = last_log.get("loss", 0.0) + epoch = last_log.get("epoch", 0.0) + except Exception: + pass + + created_at = time.strftime( + "%Y-%m-%dT%H:%M:%SZ", + time.gmtime(os.path.getmtime(ckpt_path)), + ) + + checkpoints.append(backend_pb2.CheckpointInfo( + path=ckpt_path, + step=step, + epoch=float(epoch), + loss=float(loss), + created_at=created_at, + )) + + return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints) + + def ExportModel(self, request, context): + export_format = request.export_format or "lora" + output_path = request.output_path + checkpoint_path = request.checkpoint_path + + # Extract HF token for gated model access + extra = dict(request.extra_options) if request.extra_options else {} + hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN") + + if not checkpoint_path or not os.path.isdir(checkpoint_path): + return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}") + + os.makedirs(output_path, exist_ok=True) + + try: + if export_format == "lora": + # Just copy the adapter files + import shutil + for f in os.listdir(checkpoint_path): + src = os.path.join(checkpoint_path, f) + dst = 
os.path.join(output_path, f) + if os.path.isfile(src): + shutil.copy2(src, dst) + + elif export_format in ("merged_16bit", "merged_4bit"): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import PeftModel + + base_model_name = request.model + if not base_model_name: + return backend_pb2.Result(success=False, message="Base model name required for merge export") + + dtype = torch.float16 if export_format == "merged_16bit" else torch.float32 + base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token) + model = PeftModel.from_pretrained(base_model, checkpoint_path) + merged = model.merge_and_unload() + merged.save_pretrained(output_path) + + tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token) + tokenizer.save_pretrained(output_path) + + elif export_format == "gguf": + import torch + import subprocess + import shutil + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import PeftModel + + base_model_name = request.model + if not base_model_name: + return backend_pb2.Result(success=False, message="Base model name required for GGUF export") + + # Step 1: Merge LoRA into base model + merge_dir = os.path.join(output_path, "_hf_merged") + os.makedirs(merge_dir, exist_ok=True) + + base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token) + model = PeftModel.from_pretrained(base_model, checkpoint_path) + merged = model.merge_and_unload() + merged.save_pretrained(merge_dir) + + tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token) + tokenizer.save_pretrained(merge_dir) + + # Ensure tokenizer.model (SentencePiece) is present in merge_dir. + # Gemma models need this file for GGUF conversion to use the + # SentencePiece path; without it, the script falls back to BPE + # handling which fails on unrecognized pre-tokenizer hashes. 
+ sp_model_path = os.path.join(merge_dir, "tokenizer.model") + if not os.path.exists(sp_model_path): + sp_copied = False + # Method 1: Load the slow tokenizer which keeps the SP model file + try: + slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token) + if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file): + import shutil as _shutil + _shutil.copy2(slow_tok.vocab_file, sp_model_path) + sp_copied = True + print(f"Copied tokenizer.model from slow tokenizer cache") + except Exception as e: + print(f"Slow tokenizer method failed: {e}") + # Method 2: Download from HF hub + if not sp_copied: + try: + from huggingface_hub import hf_hub_download + cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token) + import shutil as _shutil + _shutil.copy2(cached_sp, sp_model_path) + sp_copied = True + print(f"Copied tokenizer.model from HF hub") + except Exception as e: + print(f"HF hub download method failed: {e}") + if not sp_copied: + print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. " + "GGUF conversion may fail for SentencePiece models.") + + # Free GPU memory before conversion + del merged, model, base_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Step 2: Convert to GGUF using convert_hf_to_gguf.py + quant = request.quantization_method or "auto" + outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"} + outtype = outtype_map.get(quant, "f16") + + gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf" + gguf_path = os.path.join(output_path, gguf_filename) + + script_dir = os.path.dirname(os.path.abspath(__file__)) + convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py") + if not os.path.exists(convert_script): + return backend_pb2.Result(success=False, + message="convert_hf_to_gguf.py not found. 
Install the GGUF conversion tools.") + + # Log merge_dir contents for debugging conversion issues + merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else [] + print(f"Merge dir contents: {merge_files}", flush=True) + + env = os.environ.copy() + env["NO_LOCAL_GGUF"] = "1" + cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path] + conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env) + if conv_result.returncode != 0: + diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}" + return backend_pb2.Result(success=False, + message=f"GGUF conversion failed: {diag}") + + # Clean up intermediate merged model + shutil.rmtree(merge_dir, ignore_errors=True) + else: + return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}") + + except Exception as e: + if _is_gated_repo_error(e): + return backend_pb2.Result(success=False, + message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. 
" + "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.") + return backend_pb2.Result(success=False, message=f"Export failed: {e}") + + return backend_pb2.Result(success=True, message=f"Model exported to {output_path}") + + +def serve(address): + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True) + + # Handle graceful shutdown + def stop(signum, frame): + server.stop(0) + sys.exit(0) + + signal.signal(signal.SIGTERM, stop) + signal.signal(signal.SIGINT, stop) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend") + parser.add_argument("--addr", default="localhost:50051", help="gRPC server address") + args = parser.parse_args() + serve(args.addr) diff --git a/backend/python/trl/install.sh b/backend/python/trl/install.sh new file mode 100644 index 000000000..6963e60ed --- /dev/null +++ b/backend/python/trl/install.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" +installRequirements + +# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version +LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}" +CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py" +if [ ! 
-f "${CONVERT_SCRIPT}" ]; then + echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." + curl -L --fail --retry 3 \ + "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \ + -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available." +fi + +# Install gguf package from the same llama.cpp commit to keep them in sync +GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py" +echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." +if [ "x${USE_PIP:-}" == "xtrue" ]; then + pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + pip install "gguf>=0.16.0" + } +else + uv pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + uv pip install "gguf>=0.16.0" + } +fi diff --git a/backend/python/trl/requirements-cpu.txt b/backend/python/trl/requirements-cpu.txt new file mode 100644 index 000000000..c67858542 --- /dev/null +++ b/backend/python/trl/requirements-cpu.txt @@ -0,0 +1,9 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece diff --git a/backend/python/trl/requirements-cublas12.txt b/backend/python/trl/requirements-cublas12.txt new file mode 100644 index 000000000..05f29591c --- /dev/null +++ b/backend/python/trl/requirements-cublas12.txt @@ -0,0 +1,9 @@ +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece +bitsandbytes diff --git a/backend/python/trl/requirements-cublas13.txt b/backend/python/trl/requirements-cublas13.txt new file mode 100644 index 000000000..05f29591c --- /dev/null +++ 
b/backend/python/trl/requirements-cublas13.txt @@ -0,0 +1,9 @@ +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece +bitsandbytes diff --git a/backend/python/trl/requirements.txt b/backend/python/trl/requirements.txt new file mode 100644 index 000000000..0834a8fcd --- /dev/null +++ b/backend/python/trl/requirements.txt @@ -0,0 +1,3 @@ +grpcio==1.78.1 +protobuf +certifi diff --git a/backend/python/trl/reward_functions.py b/backend/python/trl/reward_functions.py new file mode 100644 index 000000000..12074f80c --- /dev/null +++ b/backend/python/trl/reward_functions.py @@ -0,0 +1,236 @@ +""" +Built-in reward functions and inline function compiler for GRPO training. + +All reward functions follow TRL's signature: (completions, **kwargs) -> list[float] +""" + +import json +import re +import math +import string +import functools + + +# --------------------------------------------------------------------------- +# Built-in reward functions +# --------------------------------------------------------------------------- + +def format_reward(completions, **kwargs): + """Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0.""" + pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL) + return [1.0 if pattern.search(c) else 0.0 for c in completions] + + +def reasoning_accuracy_reward(completions, **kwargs): + """Extracts <answer>...</answer> 
content and compares to the expected answer.""" + answers = kwargs.get("answer", []) + if not answers: + return [0.0] * len(completions) + + pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL) + scores = [] + for i, c in enumerate(completions): + expected = answers[i] if i < len(answers) else "" + match = pattern.search(c) + if match: + extracted = match.group(1).strip() + scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0) + else: + scores.append(0.0) + return scores + + +def length_reward(completions, target_length=200, **kwargs): + """Score based on proximity to target_length. Returns [0, 1].""" + scores = [] + for c in completions: + length = len(c) + if target_length <= 0: + scores.append(0.0) + else: + diff = abs(length - target_length) / target_length + scores.append(max(0.0, 1.0 - diff)) + return scores + + +def xml_tag_reward(completions, **kwargs): + """Scores properly opened/closed XML tags (<think>, <answer>).""" + tags = ["think", "answer"] + scores = [] + for c in completions: + tag_score = 0.0 + for tag in tags: + if f"<{tag}>" in c and f"</{tag}>" in c: + tag_score += 0.5 + scores.append(min(tag_score, 1.0)) + return scores + + +def no_repetition_reward(completions, n=4, **kwargs): + """Penalizes n-gram repetition. Returns [0, 1].""" + scores = [] + for c in completions: + words = c.split() + if len(words) < n: + scores.append(1.0) + continue + ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)] + unique = len(set(ngrams)) + total = len(ngrams) + scores.append(unique / total if total > 0 else 1.0) + return scores + + +def code_execution_reward(completions, **kwargs): + """Checks Python code block syntax validity via compile(). 
Returns 1.0 or 0.0.""" + pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL) + scores = [] + for c in completions: + match = pattern.search(c) + if not match: + scores.append(0.0) + continue + code = match.group(1) + try: + compile(code, "", "exec") + scores.append(1.0) + except SyntaxError: + scores.append(0.0) + return scores + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +BUILTIN_REGISTRY = { + "format_reward": format_reward, + "reasoning_accuracy_reward": reasoning_accuracy_reward, + "length_reward": length_reward, + "xml_tag_reward": xml_tag_reward, + "no_repetition_reward": no_repetition_reward, + "code_execution_reward": code_execution_reward, +} + + +# --------------------------------------------------------------------------- +# Inline function compiler +# --------------------------------------------------------------------------- + +_SAFE_BUILTINS = { + "len": len, "int": int, "float": float, "str": str, "bool": bool, + "list": list, "dict": dict, "tuple": tuple, "set": set, + "range": range, "enumerate": enumerate, "zip": zip, + "map": map, "filter": filter, "sorted": sorted, + "min": min, "max": max, "sum": sum, "abs": abs, "round": round, + "any": any, "all": all, "isinstance": isinstance, "type": type, + "print": print, "True": True, "False": False, "None": None, + "ValueError": ValueError, "TypeError": TypeError, + "KeyError": KeyError, "IndexError": IndexError, +} + + +def compile_inline_reward(name, code): + """Compile user-provided code into a reward function. + + The code should be the body of a function that receives + `completions` (list[str]) and `**kwargs`, and returns list[float]. + + Available modules: re, math, json, string. 
+ """ + func_source = ( + f"def _user_reward_{name}(completions, **kwargs):\n" + + "\n".join(f" {line}" for line in code.splitlines()) + ) + + restricted_globals = { + "__builtins__": _SAFE_BUILTINS, + "re": re, + "math": math, + "json": json, + "string": string, + } + + try: + compiled = compile(func_source, f"", "exec") + except SyntaxError as e: + raise ValueError(f"Syntax error in inline reward function '{name}': {e}") + + exec(compiled, restricted_globals) + func = restricted_globals[f"_user_reward_{name}"] + + # Validate with a quick smoke test + try: + result = func(["test"], answer=["test"]) + if not isinstance(result, list): + raise ValueError( + f"Inline reward function '{name}' must return a list, got {type(result).__name__}" + ) + except Exception as e: + if "must return a list" in str(e): + raise + # Other errors during smoke test are acceptable (e.g. missing kwargs) + pass + + return func + + +# --------------------------------------------------------------------------- +# Dispatcher +# --------------------------------------------------------------------------- + +def build_reward_functions(specs_json): + """Parse a JSON list of reward function specs and return a list of callables. + + Each spec is a dict with: + - type: "builtin" or "inline" + - name: function name + - code: (inline only) Python function body + - params: (optional) dict of string params applied via functools.partial + """ + if isinstance(specs_json, str): + specs = json.loads(specs_json) + else: + specs = specs_json + + if not isinstance(specs, list): + raise ValueError("reward_funcs must be a JSON array of reward function specs") + + reward_funcs = [] + for spec in specs: + spec_type = spec.get("type", "builtin") + name = spec.get("name", "") + params = spec.get("params", {}) + + if spec_type == "builtin": + if name not in BUILTIN_REGISTRY: + available = ", ".join(sorted(BUILTIN_REGISTRY.keys())) + raise ValueError( + f"Unknown builtin reward function '{name}'. 
Available: {available}" + ) + func = BUILTIN_REGISTRY[name] + if params: + # Convert string params to appropriate types + typed_params = {} + for k, v in params.items(): + try: + typed_params[k] = int(v) + except (ValueError, TypeError): + try: + typed_params[k] = float(v) + except (ValueError, TypeError): + typed_params[k] = v + func = functools.partial(func, **typed_params) + reward_funcs.append(func) + + elif spec_type == "inline": + code = spec.get("code", "") + if not code.strip(): + raise ValueError(f"Inline reward function '{name}' has no code") + func = compile_inline_reward(name, code) + reward_funcs.append(func) + + else: + raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'") + + return reward_funcs diff --git a/backend/python/trl/run.sh b/backend/python/trl/run.sh new file mode 100644 index 000000000..bd17c6e1d --- /dev/null +++ b/backend/python/trl/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/trl/test.py b/backend/python/trl/test.py new file mode 100644 index 000000000..d77d4e9f0 --- /dev/null +++ b/backend/python/trl/test.py @@ -0,0 +1,58 @@ +""" +Test script for the TRL fine-tuning gRPC backend. 
+""" +import unittest +import subprocess +import time + +import grpc +import backend_pb2 +import backend_pb2_grpc + + +class TestBackendServicer(unittest.TestCase): + """Tests for the TRL fine-tuning gRPC service.""" + + def setUp(self): + self.service = subprocess.Popen( + ["python3", "backend.py", "--addr", "localhost:50051"] + ) + time.sleep(10) + + def tearDown(self): + self.service.kill() + self.service.wait() + + def test_server_startup(self): + """Test that the server starts and responds to health checks.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() + + def test_list_checkpoints_empty(self): + """Test listing checkpoints on a non-existent directory.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.ListCheckpoints( + backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent") + ) + self.assertEqual(len(response.checkpoints), 0) + except Exception as err: + print(err) + self.fail("ListCheckpoints service failed") + finally: + self.tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/backend/python/trl/test.sh b/backend/python/trl/test.sh new file mode 100644 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/trl/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/cli/run.go b/core/cli/run.go index c614e123d..000dd3366 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -121,6 +121,9 @@ type RunCMD struct { AgentPoolCollectionDBPath string 
`env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` + // Fine-tuning + EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"` + // Authentication AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"` @@ -326,6 +329,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithAgentHubURL(r.AgentHubURL)) } + // Fine-tuning + if r.EnableFineTuning { + opts = append(opts, config.EnableFineTuning) + } + // Authentication authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != "" if authEnabled { diff --git a/core/config/application_config.go b/core/config/application_config.go index 9c1be82d9..bb187be43 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -97,6 +97,9 @@ type ApplicationConfig struct { // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig + // Fine-tuning + FineTuning FineTuningConfig + // Authentication & Authorization Auth AuthConfig } @@ -142,6 +145,11 @@ type AgentPoolConfig struct { AgentHubURL string // default: "https://agenthub.localai.io" } +// FineTuningConfig holds configuration for fine-tuning support. 
+type FineTuningConfig struct { + Enabled bool +} + type AppOption func(*ApplicationConfig) func NewApplicationConfig(o ...AppOption) *ApplicationConfig { @@ -733,6 +741,12 @@ func WithAgentHubURL(url string) AppOption { } } +// Fine-tuning options + +var EnableFineTuning = func(o *ApplicationConfig) { + o.FineTuning.Enabled = true +} + // Auth options func WithAuthEnabled(enabled bool) AppOption { diff --git a/core/gallery/importers/local.go b/core/gallery/importers/local.go new file mode 100644 index 000000000..2a456cc60 --- /dev/null +++ b/core/gallery/importers/local.go @@ -0,0 +1,205 @@ +package importers + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/xlog" +) + +// ImportLocalPath scans a local directory for exported model files and produces +// a config.ModelConfig with the correct backend, model path, and options. +// Paths in the returned config are relative to modelsPath when possible so that +// the YAML config remains portable. +// +// Detection order: +// 1. GGUF files (*.gguf) — uses llama-cpp backend +// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter +// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend +func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { + // Make paths relative to the models directory (parent of dirPath) + // so config YAML stays portable. + modelsDir := filepath.Dir(dirPath) + relPath := func(absPath string) string { + if rel, err := filepath.Rel(modelsDir, absPath); err == nil { + return rel + } + return absPath + } + + // 1. 
GGUF: check dirPath and dirPath_gguf/ (Unsloth convention) + ggufFile := findGGUF(dirPath) + if ggufFile == "" { + ggufSubdir := dirPath + "_gguf" + ggufFile = findGGUF(ggufSubdir) + } + if ggufFile != "" { + xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile) + cfg := &config.ModelConfig{ + Name: name, + Backend: "llama-cpp", + KnownUsecaseStrings: []string{"chat"}, + Options: []string{"use_jinja:true"}, + } + cfg.Model = relPath(ggufFile) + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.Description = buildDescription(dirPath, "GGUF") + return cfg, nil + } + + // 2. LoRA adapter: look for adapter_config.json + + adapterConfigPath := filepath.Join(dirPath, "adapter_config.json") + if fileExists(adapterConfigPath) { + xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath) + baseModel := readBaseModel(dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = baseModel + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.LLMConfig.LoraAdapter = relPath(dirPath) + cfg.Description = buildDescription(dirPath, "LoRA adapter") + return cfg, nil + } + + // Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json + if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) { + xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath) + baseModel := readBaseModel(dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = baseModel + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.LLMConfig.LoraAdapter = relPath(dirPath) + cfg.Description = buildDescription(dirPath, "LoRA adapter") + return cfg, nil + } + + // 3. 
Merged model: *.safetensors or pytorch_model*.bin + config.json + if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) { + xlog.Info("ImportLocalPath: detected merged model", "path", dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = relPath(dirPath) + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.Description = buildDescription(dirPath, "merged model") + return cfg, nil + } + + return nil, fmt.Errorf("could not detect model format in directory %s", dirPath) +} + +// findGGUF returns the path to the first .gguf file found in dir, or "". +func findGGUF(dir string) string { + entries, err := os.ReadDir(dir) + if err != nil { + return "" + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") { + return filepath.Join(dir, e.Name()) + } + } + return "" +} + +// readBaseModel reads the base model name from adapter_config.json or export_metadata.json. +func readBaseModel(dirPath string) string { + // Try adapter_config.json → base_model_name_or_path (TRL writes this) + if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil { + var ac map[string]any + if json.Unmarshal(data, &ac) == nil { + if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" { + return bm + } + } + } + + // Try export_metadata.json → base_model (Unsloth writes this) + if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil { + var meta map[string]any + if json.Unmarshal(data, &meta) == nil { + if bm, ok := meta["base_model"].(string); ok && bm != "" { + return bm + } + } + } + + return "" +} + +// buildDescription creates a human-readable description using available metadata. 
+func buildDescription(dirPath, formatLabel string) string { + base := "" + + // Try adapter_config.json + if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil { + var ac map[string]any + if json.Unmarshal(data, &ac) == nil { + if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" { + base = bm + } + } + } + + // Try export_metadata.json + if base == "" { + if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil { + var meta map[string]any + if json.Unmarshal(data, &meta) == nil { + if bm, ok := meta["base_model"].(string); ok && bm != "" { + base = bm + } + } + } + } + + if base != "" { + return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel) + } + return fmt.Sprintf("Fine-tuned model (%s)", formatLabel) +} + +func fileExists(path string) bool { + info, err := os.Stat(path) + return err == nil && !info.IsDir() +} + +func hasFileWithSuffix(dir, suffix string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) { + return true + } + } + return false +} + +func hasFileWithPrefix(dir, prefix string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, e := range entries { + if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) { + return true + } + } + return false +} diff --git a/core/gallery/importers/local_test.go b/core/gallery/importers/local_test.go new file mode 100644 index 000000000..0de679462 --- /dev/null +++ b/core/gallery/importers/local_test.go @@ -0,0 +1,148 @@ +package importers_test + +import ( + "encoding/json" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/gallery/importers" +) + +var _ = Describe("ImportLocalPath", func() { + var tmpDir string + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "importers-local-test") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tmpDir) + }) + + Context("GGUF detection", func() { + It("detects a GGUF file in the directory", func() { + modelDir := filepath.Join(tmpDir, "my-model") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "my-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("llama-cpp")) + Expect(cfg.Model).To(ContainSubstring(".gguf")) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat")) + Expect(cfg.Options).To(ContainElement("use_jinja:true")) + }) + + It("detects GGUF in _gguf subdirectory", func() { + modelDir := filepath.Join(tmpDir, "my-model") + ggufDir := modelDir + "_gguf" + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "my-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("llama-cpp")) + }) + }) + + Context("LoRA adapter detection", func() { + It("detects LoRA adapter via adapter_config.json", func() { + modelDir := filepath.Join(tmpDir, "lora-model") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{ + "base_model_name_or_path": "meta-llama/Llama-2-7b-hf", + "peft_type": "LORA", + } + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + cfg, err := 
importers.ImportLocalPath(modelDir, "lora-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf")) + Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model")) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + }) + + It("reads base model from export_metadata.json as fallback", func() { + modelDir := filepath.Join(tmpDir, "lora-unsloth") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{"peft_type": "LORA"} + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"} + data, _ = json.Marshal(metadata) + Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit")) + }) + }) + + Context("Merged model detection", func() { + It("detects merged model with safetensors + config.json", func() { + modelDir := filepath.Join(tmpDir, "merged") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "merged") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal("merged")) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + }) + + It("detects merged model with pytorch_model files", func() { + modelDir := filepath.Join(tmpDir, "merged-pt") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 
0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "merged-pt") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal("merged-pt")) + }) + }) + + Context("fallback", func() { + It("returns error for empty directory", func() { + modelDir := filepath.Join(tmpDir, "empty") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + _, err := importers.ImportLocalPath(modelDir, "empty") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("could not detect model format")) + }) + }) + + Context("description", func() { + It("includes base model name in description", func() { + modelDir := filepath.Join(tmpDir, "desc-test") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{ + "base_model_name_or_path": "TinyLlama/TinyLlama-1.1B", + } + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "desc-test") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B")) + Expect(cfg.Description).To(ContainSubstring("Fine-tuned from")) + }) + }) +}) diff --git a/core/http/app.go b/core/http/app.go index e2da479b5..696d394d2 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -302,6 +302,17 @@ func API(application *application.Application) (*echo.Echo, error) { mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP) routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw) routes.RegisterAgentPoolRoutes(e, application, agentsMw, 
skillsMw, collectionsMw) + // Fine-tuning routes + if application.ApplicationConfig().FineTuning.Enabled { + fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning) + ftService := services.NewFineTuneService( + application.ApplicationConfig(), + application.ModelLoader(), + application.ModelConfigLoader(), + ) + routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw) + } + routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) diff --git a/core/http/auth/features.go b/core/http/auth/features.go index 85b4e60ec..1d7ff4f61 100644 --- a/core/http/auth/features.go +++ b/core/http/auth/features.go @@ -85,6 +85,18 @@ var RouteFeatureRegistry = []RouteFeature{ {"POST", "/stores/delete", FeatureStores}, {"POST", "/stores/get", FeatureStores}, {"POST", "/stores/find", FeatureStores}, + + // Fine-tuning + {"POST", "/api/fine-tuning/jobs", FeatureFineTuning}, + {"GET", "/api/fine-tuning/jobs", FeatureFineTuning}, + {"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning}, + {"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning}, + {"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning}, + {"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning}, + {"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning}, + {"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning}, + {"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning}, + {"POST", "/api/fine-tuning/datasets", FeatureFineTuning}, } // FeatureMeta describes a feature for the admin API/UI. @@ -104,6 +116,13 @@ func AgentFeatureMetas() []FeatureMeta { } } +// GeneralFeatureMetas returns metadata for general features. 
+func GeneralFeatureMetas() []FeatureMeta { + return []FeatureMeta{ + {FeatureFineTuning, "Fine-Tuning", false}, + } +} + // APIFeatureMetas returns metadata for API endpoint features. func APIFeatureMetas() []FeatureMeta { return []FeatureMeta{ diff --git a/core/http/auth/permissions.go b/core/http/auth/permissions.go index b2408ad4e..1fd9bbc8b 100644 --- a/core/http/auth/permissions.go +++ b/core/http/auth/permissions.go @@ -32,6 +32,9 @@ const ( FeatureCollections = "collections" FeatureMCPJobs = "mcp_jobs" + // General features (default OFF for new users) + FeatureFineTuning = "fine_tuning" + // API features (default ON for new users) FeatureChat = "chat" FeatureImages = "images" @@ -52,6 +55,9 @@ const ( // AgentFeatures lists agent-related features (default OFF). var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs} +// GeneralFeatures lists general features (default OFF). +var GeneralFeatures = []string{FeatureFineTuning} + // APIFeatures lists API endpoint features (default ON). var APIFeatures = []string{ FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription, @@ -60,7 +66,7 @@ var APIFeatures = []string{ } // AllFeatures lists all known features (used by UI and validation). -var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...) +var AllFeatures = append(append(append([]string{}, AgentFeatures...), GeneralFeatures...), APIFeatures...) // defaultOnFeatures is the set of features that default to ON when absent from a user's permission map. 
var defaultOnFeatures = func() map[string]bool { diff --git a/core/http/endpoints/localai/agent_collections.go b/core/http/endpoints/localai/agent_collections.go index 022035ef4..dd4bd2370 100644 --- a/core/http/endpoints/localai/agent_collections.go +++ b/core/http/endpoints/localai/agent_collections.go @@ -80,13 +80,14 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } defer src.Close() - if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil { + key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src) + if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } - return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename}) + return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key}) } } diff --git a/core/http/endpoints/localai/finetune.go b/core/http/endpoints/localai/finetune.go new file mode 100644 index 000000000..fe735acb2 --- /dev/null +++ b/core/http/endpoints/localai/finetune.go @@ -0,0 +1,362 @@ +package localai + +import ( + "archive/tar" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" +) + +// StartFineTuneJobEndpoint starts a new fine-tuning job. 
+func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + + var req schema.FineTuneJobRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + if req.Model == "" { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "model is required", + }) + } + if req.DatasetSource == "" { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "dataset_source is required", + }) + } + + resp, err := ftService.StartJob(c.Request().Context(), userID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusCreated, resp) + } +} + +// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user. +func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobs := ftService.ListJobs(userID) + if jobs == nil { + jobs = []*schema.FineTuneJob{} + } + return c.JSON(http.StatusOK, jobs) + } +} + +// GetFineTuneJobEndpoint gets a specific fine-tuning job. +func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + job, err := ftService.GetJob(userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, job) + } +} + +// StopFineTuneJobEndpoint stops a running fine-tuning job. 
+func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + // Check for save_checkpoint query param + saveCheckpoint := c.QueryParam("save_checkpoint") == "true" + + err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "status": "stopped", + "message": "Fine-tuning job stopped", + }) + } +} + +// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data. +func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + err := ftService.DeleteJob(userID, jobID) + if err != nil { + status := http.StatusInternalServerError + if strings.Contains(err.Error(), "not found") { + status = http.StatusNotFound + } else if strings.Contains(err.Error(), "cannot delete") { + status = http.StatusConflict + } + return c.JSON(status, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "status": "deleted", + "message": "Fine-tuning job deleted", + }) + } +} + +// FineTuneProgressEndpoint streams progress updates via SSE. 
+func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + // Set SSE headers + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) { + data, err := json.Marshal(event) + if err != nil { + return + } + fmt.Fprintf(c.Response(), "data: %s\n\n", data) + c.Response().Flush() + }) + if err != nil { + // If headers already sent, we can't send a JSON error + fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error()) + c.Response().Flush() + } + + return nil + } +} + +// ListCheckpointsEndpoint lists checkpoints for a job. +func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]any{ + "checkpoints": checkpoints, + }) + } +} + +// ExportModelEndpoint exports a model from a checkpoint. 
+func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + var req schema.ExportRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusAccepted, map[string]string{ + "status": "exporting", + "message": "Export started for model '" + modelName + "'", + "model_name": modelName, + }) + } +} + +// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive. +func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + c.Response().Header().Set("Content-Type", "application/gzip") + c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName)) + c.Response().WriteHeader(http.StatusOK) + + gw := gzip.NewWriter(c.Response()) + defer gw.Close() + + tw := tar.NewWriter(gw) + defer tw.Close() + + err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error { + if walkErr != nil { + return walkErr + } + + relPath, err := filepath.Rel(modelDir, path) + if err != nil { + return err + } + + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return err + } + header.Name = filepath.Join(modelName, relPath) + + if err := tw.WriteHeader(header); err != nil { + return err + } + + if info.IsDir() { + return nil + } + + f, err := 
os.Open(path) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(tw, f) + return err + }) + + if err != nil { + // Headers already sent, can't return JSON error + return err + } + + return nil + } +} + +// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning". +func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to list backends: " + err.Error(), + }) + } + + type backendInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Tags []string `json:"tags,omitempty"` + } + + var result []backendInfo + for _, b := range backends { + if !b.Installed { + continue + } + hasTag := false + for _, t := range b.Tags { + if strings.EqualFold(t, "fine-tuning") { + hasTag = true + break + } + } + if !hasTag { + continue + } + name := b.Name + if b.Alias != "" { + name = b.Alias + } + result = append(result, backendInfo{ + Name: name, + Description: b.Description, + Tags: b.Tags, + }) + } + + if result == nil { + result = []backendInfo{} + } + + return c.JSON(http.StatusOK, result) + } +} + +// UploadDatasetEndpoint handles dataset file upload. 
+func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + file, err := c.FormFile("file") + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "file is required", + }) + } + + src, err := file.Open() + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to open file", + }) + } + defer src.Close() + + data, err := io.ReadAll(src) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to read file", + }) + } + + path, err := ftService.UploadDataset(file.Filename, data) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "path": path, + }) + } +} diff --git a/core/http/react-ui/src/App.css b/core/http/react-ui/src/App.css index cc2fdb8c9..8ed05628c 100644 --- a/core/http/react-ui/src/App.css +++ b/core/http/react-ui/src/App.css @@ -208,6 +208,32 @@ overflow: hidden; } +.sidebar-section-toggle { + display: flex; + align-items: center; + justify-content: space-between; + width: 100%; + background: none; + border: none; + cursor: pointer; + font-family: inherit; + transition: color var(--duration-fast); +} + +.sidebar-section-toggle:hover { + color: var(--color-text-secondary); +} + +.sidebar-section-chevron { + font-size: 0.5rem; + transition: transform var(--duration-fast); + flex-shrink: 0; +} + +.sidebar-section-toggle.open .sidebar-section-chevron { + transform: rotate(90deg); +} + .nav-item { display: flex; align-items: center; @@ -392,6 +418,10 @@ display: none; } +.sidebar.collapsed .sidebar-section-chevron { + display: none; +} + .sidebar.collapsed .nav-item { justify-content: center; padding: 8px 0; @@ -612,14 +642,6 @@ .spinner-md .spinner-ring { width: 24px; height: 24px; } .spinner-lg .spinner-ring { width: 40px; height: 40px; } -.spinner-logo { 
- animation: pulse 1.2s ease-in-out infinite; - object-fit: contain; -} -.spinner-sm .spinner-logo { width: 16px; height: 16px; } -.spinner-md .spinner-logo { width: 24px; height: 24px; } -.spinner-lg .spinner-logo { width: 40px; height: 40px; } - /* Model selector */ .model-selector { background: var(--color-bg-tertiary); @@ -2646,6 +2668,43 @@ font-size: 0.625rem; } +/* Studio tabs */ +.studio-tabs { + display: flex; + gap: 0; + border-bottom: 1px solid var(--color-border-subtle); + padding: 0 var(--spacing-xl); + background: var(--color-bg-primary); + position: sticky; + top: 0; + z-index: 10; +} + +.studio-tab { + display: flex; + align-items: center; + gap: 6px; + background: none; + border: none; + padding: var(--spacing-sm) var(--spacing-md); + font-size: 0.8125rem; + font-family: inherit; + color: var(--color-text-secondary); + cursor: pointer; + border-bottom: 2px solid transparent; + transition: color var(--duration-fast), border-color var(--duration-fast); +} + +.studio-tab:hover { + color: var(--color-text-primary); +} + +.studio-tab-active { + color: var(--color-primary); + border-bottom-color: var(--color-primary); + font-weight: 500; +} + /* Two-column layout for media generation pages */ .media-layout { display: grid; diff --git a/core/http/react-ui/src/components/LoadingSpinner.jsx b/core/http/react-ui/src/components/LoadingSpinner.jsx index 23f858abe..b1c1b46a2 100644 --- a/core/http/react-ui/src/components/LoadingSpinner.jsx +++ b/core/http/react-ui/src/components/LoadingSpinner.jsx @@ -1,22 +1,8 @@ -import { useState } from 'react' -import { apiUrl } from '../utils/basePath' - export default function LoadingSpinner({ size = 'md', className = '' }) { const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md' - const [imgFailed, setImgFailed] = useState(false) - return (
- {imgFailed ? ( -
- ) : ( - setImgFailed(true)} - /> - )} +
) } diff --git a/core/http/react-ui/src/components/Sidebar.jsx b/core/http/react-ui/src/components/Sidebar.jsx index eec32f88f..4b0ebc229 100644 --- a/core/http/react-ui/src/components/Sidebar.jsx +++ b/core/http/react-ui/src/components/Sidebar.jsx @@ -1,37 +1,57 @@ import { useState, useEffect } from 'react' -import { NavLink, useNavigate } from 'react-router-dom' +import { NavLink, useNavigate, useLocation } from 'react-router-dom' import ThemeToggle from './ThemeToggle' import { useAuth } from '../context/AuthContext' import { apiUrl } from '../utils/basePath' const COLLAPSED_KEY = 'localai_sidebar_collapsed' +const SECTIONS_KEY = 'localai_sidebar_sections' -const mainItems = [ +const topItems = [ { path: '/app', icon: 'fas fa-home', label: 'Home' }, { path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true }, { path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' }, - { path: '/app/image', icon: 'fas fa-image', label: 'Images' }, - { path: '/app/video', icon: 'fas fa-video', label: 'Video' }, - { path: '/app/tts', icon: 'fas fa-music', label: 'TTS' }, - { path: '/app/sound', icon: 'fas fa-volume-high', label: 'Sound' }, + { path: '/app/studio', icon: 'fas fa-palette', label: 'Studio' }, { path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' }, - { path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true }, ] -const agentItems = [ - { path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' }, - { path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' }, - { path: '/app/collections', icon: 'fas fa-database', label: 'Memory' }, - { path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' }, -] - -const systemItems = [ - { path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true }, - { path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true }, - { path: '/app/traces', icon: 'fas fa-chart-line', label: 
'Traces', adminOnly: true }, - { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true }, - { path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true }, - { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true }, +const sections = [ + { + id: 'tools', + title: 'Tools', + items: [ + { path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' }, + ], + }, + { + id: 'agents', + title: 'Agents', + featureMap: { + '/app/agents': 'agents', + '/app/skills': 'skills', + '/app/collections': 'collections', + '/app/agent-jobs': 'mcp_jobs', + }, + items: [ + { path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' }, + { path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' }, + { path: '/app/collections', icon: 'fas fa-database', label: 'Memory' }, + { path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' }, + ], + }, + { + id: 'system', + title: 'System', + items: [ + { path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true }, + { path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true }, + { path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true }, + { path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true }, + { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true }, + { path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true }, + { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true }, + ], + }, ] function NavItem({ item, onClose, collapsed }) { @@ -51,18 +71,47 @@ function NavItem({ item, onClose, collapsed }) { ) } +function loadSectionState() { + try { + const stored = localStorage.getItem(SECTIONS_KEY) + return stored ? 
JSON.parse(stored) : {} + } catch (_) { + return {} + } +} + +function saveSectionState(state) { + try { localStorage.setItem(SECTIONS_KEY, JSON.stringify(state)) } catch (_) { /* ignore */ } +} + export default function Sidebar({ isOpen, onClose }) { const [features, setFeatures] = useState({}) const [collapsed, setCollapsed] = useState(() => { try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false } }) + const [openSections, setOpenSections] = useState(loadSectionState) const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth() const navigate = useNavigate() + const location = useLocation() useEffect(() => { fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {}) }, []) + // Auto-expand section containing the active route + useEffect(() => { + for (const section of sections) { + const match = section.items.some(item => location.pathname.startsWith(item.path)) + if (match && !openSections[section.id]) { + setOpenSections(prev => { + const next = { ...prev, [section.id]: true } + saveSectionState(next) + return next + }) + } + } + }, [location.pathname]) + const toggleCollapse = () => { setCollapsed(prev => { const next = !prev @@ -72,17 +121,34 @@ export default function Sidebar({ isOpen, onClose }) { }) } - const visibleMainItems = mainItems.filter(item => { - if (item.adminOnly && !isAdmin) return false - if (item.authOnly && !authEnabled) return false - return true - }) + const toggleSection = (id) => { + setOpenSections(prev => { + const next = { ...prev, [id]: !prev[id] } + saveSectionState(next) + return next + }) + } - const visibleSystemItems = systemItems.filter(item => { + const filterItem = (item) => { if (item.adminOnly && !isAdmin) return false if (item.authOnly && !authEnabled) return false + if (item.feature && features[item.feature] === false) return false + if (item.feature && !hasFeature(item.feature)) return false return true - }) + } + + const visibleTopItems = 
topItems.filter(filterItem) + + const getVisibleSectionItems = (section) => { + return section.items.filter(item => { + if (!filterItem(item)) return false + if (section.featureMap) { + const featureName = section.featureMap[item.path] + return featureName ? hasFeature(featureName) : isAdmin + } + return true + }) + } return ( <> @@ -104,57 +170,57 @@ export default function Sidebar({ isOpen, onClose }) { {/* Navigation */} {/* Footer */} diff --git a/core/http/react-ui/src/pages/FineTune.jsx b/core/http/react-ui/src/pages/FineTune.jsx new file mode 100644 index 000000000..606848754 --- /dev/null +++ b/core/http/react-ui/src/pages/FineTune.jsx @@ -0,0 +1,1525 @@ +import { useState, useEffect, useRef, useCallback } from 'react' +import { fineTuneApi } from '../utils/api' +import LoadingSpinner from '../components/LoadingSpinner' + +const TRAINING_METHODS = ['sft', 'dpo', 'grpo', 'rloo', 'reward', 'kto', 'orpo'] +const TRAINING_TYPES = ['lora', 'loha', 'lokr', 'full'] +const FALLBACK_BACKENDS = ['trl'] +const OPTIMIZERS = ['adamw_torch', 'adamw_8bit', 'sgd', 'adafactor', 'prodigy'] +const MIXED_PRECISION_OPTS = ['', 'fp16', 'bf16', 'no'] + +const BUILTIN_REWARDS = [ + { name: 'format_reward', description: 'Checks ... 
then answer format', params: [] }, + { name: 'reasoning_accuracy_reward', description: 'Compares content to dataset answer column', params: [] }, + { name: 'length_reward', description: 'Score based on proximity to target length', params: [{ key: 'target_length', default: '200', label: 'Target Length' }] }, + { name: 'xml_tag_reward', description: 'Scores properly opened/closed XML tags', params: [] }, + { name: 'no_repetition_reward', description: 'Penalizes n-gram repetition', params: [] }, + { name: 'code_execution_reward', description: 'Checks Python code block syntax validity', params: [] }, +] + +const statusBadgeClass = { + queued: '', + loading_model: 'badge-warning', + loading_dataset: 'badge-warning', + training: 'badge-info', + saving: 'badge-info', + completed: 'badge-success', + failed: 'badge-error', + stopped: '', +} + +function FormSection({ icon, title, children }) { + return ( +
+

+ + {title} +

+ {children} +
+ ) +} + +function KeyValueEditor({ entries, onChange }) { + const addEntry = () => onChange([...entries, { key: '', value: '' }]) + const removeEntry = (i) => onChange(entries.filter((_, idx) => idx !== i)) + const updateEntry = (i, field, val) => { + const updated = entries.map((e, idx) => idx === i ? { ...e, [field]: val } : e) + onChange(updated) + } + + return ( +
+ {entries.map((entry, i) => ( +
+ updateEntry(i, 'key', e.target.value)} + placeholder="Key" + style={{ flex: 1 }} + /> + updateEntry(i, 'value', e.target.value)} + placeholder="Value" + style={{ flex: 2 }} + /> + +
+ ))} + +
+ ) +} + +function CopyButton({ text }) { + const [copied, setCopied] = useState(false) + const handleCopy = (e) => { + e.stopPropagation() + navigator.clipboard.writeText(text).then(() => { + setCopied(true) + setTimeout(() => setCopied(false), 1500) + }) + } + return ( + + ) +} + +function JobCard({ job, isSelected, onSelect, onUseConfig, onDelete }) { + return ( +
onSelect(job)} + > +
+
+ {job.model} + + {job.backend} / {job.training_method || 'sft'} + +
+
+ + {['completed', 'stopped', 'failed'].includes(job.status) && ( + + )} + + {job.status} + +
+
+
+ ID: {job.id?.slice(0, 8)}... | Created: {job.created_at} +
+ {job.output_dir && ( +
+ + + {job.output_dir} + + +
+ )} + {job.message && ( +
+ + {job.message} +
+ )} +
+ ) +} + +function formatEta(seconds) { + if (!seconds || seconds <= 0) return '--' + const h = Math.floor(seconds / 3600) + const m = Math.floor((seconds % 3600) / 60) + const s = Math.floor(seconds % 60) + if (h > 0) return `${h}h ${m}m` + if (m > 0) return `${m}m ${s}s` + return `${s}s` +} + +function formatAxisValue(val, decimals) { + if (val >= 1) return val.toFixed(Math.min(decimals, 1)) + if (val >= 0.01) return val.toFixed(Math.min(decimals, 3)) + return val.toExponential(1) +} + +function SingleMetricChart({ data, valueKey, label, color, formatValue, events }) { + const [tooltip, setTooltip] = useState(null) + const svgRef = useRef(null) + + if (!data || data.length < 1) return null + + const pad = { top: 16, right: 12, bottom: 32, left: 52 } + const W = 400, H = 220 + const cw = W - pad.left - pad.right + const ch = H - pad.top - pad.bottom + + const steps = data.map(e => e.current_step) + const values = data.map(e => e[valueKey]) + + const minStep = Math.min(...steps), maxStep = Math.max(...steps) + const stepRange = maxStep - minStep || 1 + const minVal = Math.min(...values), maxVal = Math.max(...values) + const valRange = maxVal - minVal || 1 + const valPad = valRange * 0.05 + const yMin = Math.max(0, minVal - valPad), yMax = maxVal + valPad + const yRange = yMax - yMin || 1 + + const x = (step) => pad.left + ((step - minStep) / stepRange) * cw + const y = (val) => pad.top + (1 - (val - yMin) / yRange) * ch + + const points = data.map(e => `${x(e.current_step)},${y(e[valueKey])}`).join(' ') + + const xTickCount = Math.min(5, data.length) + const xTicks = Array.from({ length: xTickCount }, (_, i) => Math.round(minStep + (stepRange * i) / (xTickCount - 1))) + const yTickCount = 4 + const yTicks = Array.from({ length: yTickCount }, (_, i) => yMin + (yRange * i) / (yTickCount - 1)) + + // Epoch boundaries from the full events list if provided + const epochBoundaries = [] + const evts = events || data + for (let i = 1; i < evts.length; i++) { + const 
prevEpoch = Math.floor(evts[i - 1].current_epoch || 0) + const curEpoch = Math.floor(evts[i].current_epoch || 0) + if (curEpoch > prevEpoch && curEpoch > 0) { + epochBoundaries.push({ step: evts[i].current_step, epoch: curEpoch }) + } + } + + const fmtVal = formatValue || ((v) => formatAxisValue(v, 3)) + + const handleMouseMove = (e) => { + if (!svgRef.current) return + const rect = svgRef.current.getBoundingClientRect() + const mx = ((e.clientX - rect.left) / rect.width) * W + const step = minStep + ((mx - pad.left) / cw) * stepRange + let nearest = data[0], bestDist = Infinity + for (const d of data) { + const dist = Math.abs(d.current_step - step) + if (dist < bestDist) { bestDist = dist; nearest = d } + } + setTooltip({ x: x(nearest.current_step), y: y(nearest[valueKey]), data: nearest }) + } + + return ( +
+
+ + {label} +
+ setTooltip(null)} + > + {yTicks.map((val, i) => ( + + ))} + {epochBoundaries.map((eb, i) => ( + + + + ))} + + + {xTicks.map((step, i) => ( + {step} + ))} + + {yTicks.map((val, i) => ( + {fmtVal(val)} + ))} + Step + {tooltip && ( + + + + + + Step {tooltip.data.current_step} + + + {fmtVal(tooltip.data[valueKey])} + + + )} + +
+ ) +} + +function ChartsGrid({ events }) { + const lossData = events.filter(e => e.loss > 0) + const evalData = events.filter(e => e.eval_loss > 0) + const lrData = events.filter(e => e.learning_rate != null && e.learning_rate > 0) + const gradNormData = events.filter(e => e.grad_norm != null && e.grad_norm > 0) + + const fmtExp = (v) => v.toExponential(1) + + if (lossData.length < 2 && lrData.length < 2 && gradNormData.length < 2) return null + + return ( +
+ + {evalData.length >= 1 ? ( + + ) : ( +
+ + + Eval Loss — waiting for eval data + +
+ )} + + +
+ ) +} + +function TrainingMonitor({ job, onStop }) { + const [events, setEvents] = useState([]) + const [latest, setLatest] = useState(null) + const [connecting, setConnecting] = useState(true) + const eventSourceRef = useRef(null) + + useEffect(() => { + if (!job || !['queued', 'loading_model', 'loading_dataset', 'training', 'saving'].includes(job.status)) { + setConnecting(false) + return + } + + setConnecting(true) + setLatest(null) + setEvents([]) + + const url = fineTuneApi.progressUrl(job.id) + const es = new EventSource(url) + eventSourceRef.current = es + + es.onmessage = (e) => { + try { + setConnecting(false) + const data = JSON.parse(e.data) + setLatest(data) + if (data.loss > 0) { + setEvents(prev => [...prev, data]) + } + if (['completed', 'failed', 'stopped'].includes(data.status)) { + es.close() + } + } catch (_) {} + } + + es.onerror = () => { + setConnecting(false) + es.close() + } + + return () => { + es.close() + } + }, [job?.id]) + + if (!job) return null + + return ( +
+

+ + Training Monitor +

+ + {connecting && !latest && ( +
+ Connecting to training stream... +
+ )} + + {latest && ( +
+
+
Status
+
{latest.status}
+
+
+
Progress
+
{latest.progress_percent?.toFixed(1)}%
+
+
+
Step
+
{latest.current_step} / {latest.total_steps}
+
+
+
Loss
+
{latest.loss?.toFixed(4)}
+
+
+
Epoch
+
{latest.current_epoch?.toFixed(2)} / {latest.total_epochs?.toFixed(0)}
+
+
+
Learning Rate
+
{latest.learning_rate?.toExponential(2)}
+
+
+
ETA
+
{formatEta(latest.eta_seconds)}
+
+ {latest.extra_metrics?.tokens_per_second > 0 && ( +
+
Tokens/sec
+
{latest.extra_metrics.tokens_per_second.toFixed(0)}
+
+ )} +
+ )} + + {/* Progress bar */} + {latest && ( +
+
+
+ )} + + {/* Training charts (2x2 grid) */} + + + {latest?.message && ( +
+ + {latest.message} +
+ )} + + {['queued', 'loading_model', 'loading_dataset', 'training', 'saving'].includes(latest?.status || job.status) && ( + + )} +
+ ) +} + +function CheckpointsPanel({ job, onResume, onExportCheckpoint }) { + const [checkpoints, setCheckpoints] = useState([]) + const [loading, setLoading] = useState(false) + + useEffect(() => { + if (!job) return + setLoading(true) + fineTuneApi.listCheckpoints(job.id).then(r => { + setCheckpoints(r.checkpoints || []) + }).catch(() => {}).finally(() => setLoading(false)) + }, [job?.id]) + + if (!job) return null + if (loading) return
Loading checkpoints...
+ if (checkpoints.length === 0) return null + + return ( +
+

+ + Checkpoints +

+
+ + + + + + + + + + + + + {checkpoints.map(cp => ( + + + + + + + + + ))} + +
StepEpochLossCreatedPathActions
{cp.step}{cp.epoch?.toFixed(2)}{cp.loss?.toFixed(4)}{cp.created_at} + {cp.path} + + + +
+
+
+ ) +} + +const QUANT_PRESETS = ['q4_k_m', 'q5_k_m', 'q8_0', 'f16', 'q4_0', 'q5_0'] + +function ExportPanel({ job, prefilledCheckpoint }) { + const [checkpoints, setCheckpoints] = useState([]) + const [exportFormat, setExportFormat] = useState('lora') + const [quantMethod, setQuantMethod] = useState('q4_k_m') + const [modelName, setModelName] = useState('') + const [selectedCheckpoint, setSelectedCheckpoint] = useState('') + const [exporting, setExporting] = useState(false) + const [message, setMessage] = useState('') + const [exportedModelName, setExportedModelName] = useState('') + const pollRef = useRef(null) + + useEffect(() => { + if (!job) return + fineTuneApi.listCheckpoints(job.id).then(r => { + setCheckpoints(r.checkpoints || []) + }).catch(() => {}) + }, [job?.id]) + + // Apply prefilled checkpoint when set + useEffect(() => { + if (prefilledCheckpoint) { + setSelectedCheckpoint(prefilledCheckpoint.path || '') + } + }, [prefilledCheckpoint]) + + // Sync export state from job (e.g. 
on initial load or job list refresh) + useEffect(() => { + if (!job) return + if (job.export_status === 'exporting') { + setExporting(true) + setMessage(job.export_message || 'Export in progress...') + } else if (job.export_status === 'completed' && job.export_model_name) { + setExporting(false) + setExportedModelName(job.export_model_name) + setMessage(`Model exported and registered as "${job.export_model_name}"`) + } else if (job.export_status === 'failed') { + setExporting(false) + setMessage(`Export failed: ${job.export_message || 'unknown error'}`) + } + }, [job?.export_status, job?.export_model_name, job?.export_message]) + + // Poll for export completion + useEffect(() => { + if (!exporting || !job) return + + pollRef.current = setInterval(async () => { + try { + const updated = await fineTuneApi.getJob(job.id) + if (updated.export_status === 'completed') { + setExporting(false) + const name = updated.export_model_name || modelName || 'exported model' + setExportedModelName(name) + setMessage(`Model exported and registered as "${name}"`) + clearInterval(pollRef.current) + } else if (updated.export_status === 'failed') { + setExporting(false) + setMessage(`Export failed: ${updated.export_message || 'unknown error'}`) + clearInterval(pollRef.current) + } else if (updated.export_status === 'exporting' && updated.export_message) { + setMessage(updated.export_message) + } + } catch (_) {} + }, 3000) + + return () => clearInterval(pollRef.current) + }, [exporting, job?.id]) + + const handleExport = async () => { + setExporting(true) + setMessage('Export in progress...') + setExportedModelName('') + try { + await fineTuneApi.exportModel(job.id, { + name: modelName || undefined, + checkpoint_path: selectedCheckpoint || job.output_dir, + export_format: exportFormat, + quantization_method: exportFormat === 'gguf' ? 
quantMethod : '', + model: job.model, + }) + // Polling will pick up completion/failure + } catch (e) { + setMessage(`Export failed: ${e.message}`) + setExporting(false) + } + } + + // Show export panel for completed, stopped, and failed jobs (checkpoints may exist) + if (!job || !['completed', 'stopped', 'failed'].includes(job.status)) return null + + return ( +
+

+ + Export Model +

+ + {checkpoints.length > 0 && ( +
+ + +
+ )} + +
+
+ + +
+ {exportFormat === 'gguf' && ( +
+ + setQuantMethod(e.target.value)} + placeholder="e.g. q4_k_m, bf16, f32" + className="input" + /> + + {QUANT_PRESETS.map(q => ( + +
+ )} +
+ +
+ + setModelName(e.target.value)} + placeholder="e.g. my-finetuned-model" + className="input" + /> +
+ + + + {message && ( +
+ {exporting && } {message} + {exportedModelName && !message.includes('failed') && ( + + + Chat with {exportedModelName} + + + Download Archive + + + )} +
+ )} +
+ ) +} + +export default function FineTune() { + const [jobs, setJobs] = useState([]) + const [selectedJob, setSelectedJob] = useState(null) + const [showForm, setShowForm] = useState(false) + const [loading, setLoading] = useState(false) + const [error, setError] = useState('') + const [backends, setBackends] = useState([]) + const [exportCheckpoint, setExportCheckpoint] = useState(null) + + // Form state + const [model, setModel] = useState('') + const [backend, setBackend] = useState('') + const [trainingMethod, setTrainingMethod] = useState('sft') + const [trainingType, setTrainingType] = useState('lora') + const [datasetSource, setDatasetSource] = useState('') + const [datasetFile, setDatasetFile] = useState(null) + const [datasetSplit, setDatasetSplit] = useState('') + const [numEpochs, setNumEpochs] = useState(3) + const [batchSize, setBatchSize] = useState(2) + const [learningRate, setLearningRate] = useState(0.0002) + const [learningRateText, setLearningRateText] = useState('0.0002') + const [adapterRank, setAdapterRank] = useState(16) + const [adapterAlpha, setAdapterAlpha] = useState(16) + const [adapterDropout, setAdapterDropout] = useState(0) + const [targetModules, setTargetModules] = useState('') + const [gradAccum, setGradAccum] = useState(4) + const [warmupSteps, setWarmupSteps] = useState(5) + const [maxSteps, setMaxSteps] = useState(0) + const [saveSteps, setSaveSteps] = useState(500) + const [weightDecay, setWeightDecay] = useState(0) + const [maxSeqLength, setMaxSeqLength] = useState(2048) + const [optimizer, setOptimizer] = useState('adamw_torch') + const [gradCheckpointing, setGradCheckpointing] = useState(false) + const [seed, setSeed] = useState(0) + const [mixedPrecision, setMixedPrecision] = useState('') + const [extraOptions, setExtraOptions] = useState([]) + const [hfToken, setHfToken] = useState('') + const [showAdvanced, setShowAdvanced] = useState(false) + const [resumeFromCheckpoint, setResumeFromCheckpoint] = useState('') + const 
[saveTotalLimit, setSaveTotalLimit] = useState(0) + const [evalEnabled, setEvalEnabled] = useState(false) + const [evalStrategy, setEvalStrategy] = useState('steps') + const [evalSteps, setEvalSteps] = useState(0) + const [evalSplit, setEvalSplit] = useState('') + const [evalDatasetSource, setEvalDatasetSource] = useState('') + const [evalSplitRatio, setEvalSplitRatio] = useState(0.1) + const [rewardFunctions, setRewardFunctions] = useState([]) // [{type, name, code?, params?}] + const [showAddCustomReward, setShowAddCustomReward] = useState(false) + const [customRewardName, setCustomRewardName] = useState('') + const [customRewardCode, setCustomRewardCode] = useState('') + + const loadJobs = useCallback(async () => { + try { + const data = await fineTuneApi.listJobs() + setJobs(data || []) + } catch (_) {} + }, []) + + useEffect(() => { + loadJobs() + const interval = setInterval(loadJobs, 10000) + return () => clearInterval(interval) + }, [loadJobs]) + + useEffect(() => { + fineTuneApi.listBackends() + .then(data => { + const names = data && data.length > 0 ? 
data.map(b => b.name) : FALLBACK_BACKENDS + setBackends(names) + setBackend(prev => prev || names[0] || '') + }) + .catch(() => { + setBackends(FALLBACK_BACKENDS) + setBackend(prev => prev || FALLBACK_BACKENDS[0]) + }) + }, []) + + const handleSubmit = async (e) => { + e.preventDefault() + setLoading(true) + setError('') + + try { + let dsSource = datasetSource + if (datasetFile) { + const result = await fineTuneApi.uploadDataset(datasetFile) + dsSource = result.path + } + + const extra = {} + if (maxSeqLength) extra.max_seq_length = String(maxSeqLength) + if (hfToken.trim()) extra.hf_token = hfToken.trim() + if (saveTotalLimit > 0) extra.save_total_limit = String(saveTotalLimit) + if (evalEnabled) { + extra.eval_strategy = evalStrategy || 'steps' + if (evalSteps > 0) extra.eval_steps = String(evalSteps) + if (evalSplit.trim()) extra.eval_split = evalSplit.trim() + if (evalDatasetSource.trim()) extra.eval_dataset_source = evalDatasetSource.trim() + if (evalSplitRatio > 0 && evalSplitRatio !== 0.1) extra.eval_split_ratio = String(evalSplitRatio) + } else { + extra.eval_strategy = 'no' + } + for (const { key, value } of extraOptions) { + if (key.trim()) extra[key.trim()] = value + } + + const isAdapter = ['lora', 'loha', 'lokr'].includes(trainingType) + + const req = { + model, + backend, + training_method: trainingMethod, + training_type: trainingType, + dataset_source: dsSource, + dataset_split: datasetSplit || undefined, + num_epochs: numEpochs, + batch_size: batchSize, + learning_rate: learningRate, + adapter_rank: isAdapter ? adapterRank : 0, + adapter_alpha: isAdapter ? adapterAlpha : 0, + adapter_dropout: isAdapter && adapterDropout > 0 ? adapterDropout : undefined, + target_modules: isAdapter && targetModules.trim() ? targetModules.split(',').map(s => s.trim()) : undefined, + gradient_accumulation_steps: gradAccum, + warmup_steps: warmupSteps, + max_steps: maxSteps > 0 ? maxSteps : undefined, + save_steps: saveSteps > 0 ? 
saveSteps : undefined, + weight_decay: weightDecay > 0 ? weightDecay : undefined, + gradient_checkpointing: gradCheckpointing, + optimizer, + seed: seed > 0 ? seed : undefined, + mixed_precision: mixedPrecision || undefined, + resume_from_checkpoint: resumeFromCheckpoint || undefined, + extra_options: Object.keys(extra).length > 0 ? extra : undefined, + reward_functions: trainingMethod === 'grpo' && rewardFunctions.length > 0 ? rewardFunctions : undefined, + } + + const resp = await fineTuneApi.startJob(req) + setShowForm(false) + setResumeFromCheckpoint('') + await loadJobs() + + const newJob = { ...req, id: resp.id, status: 'queued', created_at: new Date().toISOString() } + setSelectedJob(newJob) + } catch (err) { + setError(err.message) + } + setLoading(false) + } + + const handleStop = async (jobId) => { + try { + await fineTuneApi.stopJob(jobId, true) + await loadJobs() + } catch (err) { + setError(err.message) + } + } + + const handleDelete = async (jobId) => { + if (!window.confirm('Delete this job and all its data (checkpoints, exported model)? This cannot be undone.')) return + try { + await fineTuneApi.deleteJob(jobId) + if (selectedJob?.id === jobId) setSelectedJob(null) + await loadJobs() + } catch (err) { + setError(err.message) + } + } + + const isAdapter = ['lora', 'loha', 'lokr'].includes(trainingType) + + const getFormConfig = () => { + const extra = {} + for (const { key, value } of extraOptions) { + if (key.trim()) extra[key.trim()] = value + } + return { + model, + backend, + training_method: trainingMethod, + training_type: trainingType, + adapter_rank: adapterRank, + adapter_alpha: adapterAlpha, + adapter_dropout: adapterDropout, + target_modules: targetModules.trim() ? 
targetModules.split(',').map(s => s.trim()) : [], + dataset_source: datasetSource, + dataset_split: datasetSplit, + num_epochs: numEpochs, + batch_size: batchSize, + learning_rate: learningRate, + gradient_accumulation_steps: gradAccum, + warmup_steps: warmupSteps, + max_steps: maxSteps, + save_steps: saveSteps, + weight_decay: weightDecay, + gradient_checkpointing: gradCheckpointing, + optimizer, + seed, + mixed_precision: mixedPrecision, + max_seq_length: maxSeqLength, + eval_strategy: evalEnabled ? (evalStrategy || 'steps') : 'no', + eval_steps: evalSteps, + eval_split: evalSplit, + eval_dataset_source: evalDatasetSource, + eval_split_ratio: evalSplitRatio, + extra_options: Object.keys(extra).length > 0 ? extra : {}, + reward_functions: rewardFunctions.length > 0 ? rewardFunctions : undefined, + } + } + + const applyFormConfig = (config) => { + if (config.model != null) setModel(config.model) + if (config.backend != null) setBackend(config.backend) + if (config.training_method != null) setTrainingMethod(config.training_method) + if (config.training_type != null) setTrainingType(config.training_type) + if (config.adapter_rank != null) setAdapterRank(Number(config.adapter_rank)) + if (config.adapter_alpha != null) setAdapterAlpha(Number(config.adapter_alpha)) + if (config.adapter_dropout != null) setAdapterDropout(Number(config.adapter_dropout)) + if (config.target_modules != null) { + const modules = Array.isArray(config.target_modules) + ? 
config.target_modules.join(', ') + : String(config.target_modules) + setTargetModules(modules) + } + if (config.dataset_source != null) setDatasetSource(config.dataset_source) + if (config.dataset_split != null) setDatasetSplit(config.dataset_split) + if (config.num_epochs != null) setNumEpochs(Number(config.num_epochs)) + if (config.batch_size != null) setBatchSize(Number(config.batch_size)) + if (config.learning_rate != null) { setLearningRate(Number(config.learning_rate)); setLearningRateText(String(config.learning_rate)) } + if (config.gradient_accumulation_steps != null) setGradAccum(Number(config.gradient_accumulation_steps)) + if (config.warmup_steps != null) setWarmupSteps(Number(config.warmup_steps)) + if (config.max_steps != null) setMaxSteps(Number(config.max_steps)) + if (config.save_steps != null) setSaveSteps(Number(config.save_steps)) + if (config.weight_decay != null) setWeightDecay(Number(config.weight_decay)) + if (config.gradient_checkpointing != null) setGradCheckpointing(Boolean(config.gradient_checkpointing)) + if (config.optimizer != null) setOptimizer(config.optimizer) + if (config.seed != null) setSeed(Number(config.seed)) + if (config.mixed_precision != null) setMixedPrecision(config.mixed_precision) + + // Handle max_seq_length: top-level field or inside extra_options + if (config.max_seq_length != null) { + setMaxSeqLength(Number(config.max_seq_length)) + } else if (config.extra_options?.max_seq_length != null) { + setMaxSeqLength(Number(config.extra_options.max_seq_length)) + } + + // Eval options — detect enabled state from strategy + const restoreEval = (strategy, steps, split, src, ratio) => { + if (strategy != null && strategy !== 'no') { + setEvalEnabled(true) + setEvalStrategy(strategy) + } else if (strategy === 'no') { + setEvalEnabled(false) + } + if (steps != null) setEvalSteps(Number(steps)) + if (split != null) setEvalSplit(split) + if (src != null) setEvalDatasetSource(src) + if (ratio != null) 
setEvalSplitRatio(Number(ratio)) + } + restoreEval(config.eval_strategy, config.eval_steps, config.eval_split, config.eval_dataset_source, config.eval_split_ratio) + // Also restore from extra_options if present (overrides top-level) + const eo = config.extra_options + if (eo) restoreEval(eo.eval_strategy, eo.eval_steps, eo.eval_split, eo.eval_dataset_source, eo.eval_split_ratio) + + // Handle save_total_limit from extra_options + if (config.extra_options?.save_total_limit != null) { + setSaveTotalLimit(Number(config.extra_options.save_total_limit)) + } + + // Convert extra_options object to [{key, value}] entries, filtering out handled keys + if (config.extra_options && typeof config.extra_options === 'object') { + const entries = Object.entries(config.extra_options) + .filter(([k]) => !['max_seq_length', 'save_total_limit', 'hf_token', 'eval_strategy', 'eval_steps', 'eval_split', 'eval_dataset_source', 'eval_split_ratio'].includes(k)) + .map(([key, value]) => ({ key, value: String(value) })) + setExtraOptions(entries) + } + + // Restore reward functions + if (Array.isArray(config.reward_functions)) { + setRewardFunctions(config.reward_functions) + } else { + setRewardFunctions([]) + } + } + + const handleExportConfig = () => { + const config = getFormConfig() + const json = JSON.stringify(config, null, 2) + const blob = new Blob([json], { type: 'application/json' }) + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = 'finetune-config.json' + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(url) + } + + const handleImportConfig = () => { + const input = document.createElement('input') + input.type = 'file' + input.accept = '.json' + input.onchange = (e) => { + const file = e.target.files[0] + if (!file) return + const reader = new FileReader() + reader.onload = (ev) => { + try { + const config = JSON.parse(ev.target.result) + applyFormConfig(config) + 
setShowForm(true) + setError('') + } catch { + setError('Failed to parse config file. Please ensure it is valid JSON.') + } + } + reader.readAsText(file) + } + input.click() + } + + const handleUseConfig = (job) => { + // Prefer the stored config if available, otherwise use the job fields + applyFormConfig(job.config || job) + setResumeFromCheckpoint('') + setShowForm(true) + } + + const handleResumeFromCheckpoint = (checkpoint) => { + if (!selectedJob) return + // Apply the original job's config + applyFormConfig(selectedJob.config || selectedJob) + setResumeFromCheckpoint(checkpoint.path) + setShowAdvanced(true) + setShowForm(true) + } + + const handleExportCheckpoint = (checkpoint) => { + setExportCheckpoint(checkpoint) + } + + return ( +
+
+
+

Fine-Tuning

+

Create and manage fine-tuning jobs

+
+
+ + +
+
+ + {error && ( +
+ {error} +
+ )} + + {showForm && ( +
+ + {resumeFromCheckpoint && ( +
+ + + Resuming from checkpoint: {resumeFromCheckpoint} + + +
+ )} + + +
+
+ + +
+
+ + +
+
+ + setModel(e.target.value)} placeholder="e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0" className="input" required /> +
+
+
+ + setHfToken(e.target.value)} placeholder="hf_..." className="input" /> +
+
+ + +
+
+ + +
+ {isAdapter && ( + <> +
+ + setAdapterRank(Number(e.target.value))} className="input" min={1} /> +
+
+ + setAdapterAlpha(Number(e.target.value))} className="input" min={1} /> +
+
+ + setAdapterDropout(Number(e.target.value))} className="input" min={0} max={1} step={0.05} /> +
+ + )} +
+ {isAdapter && ( +
+ + setTargetModules(e.target.value)} placeholder="e.g. q_proj, v_proj, k_proj, o_proj" className="input" /> +
+ )} +
+ + +
+
+ + setDatasetSource(e.target.value)} placeholder="e.g. tatsu-lab/alpaca" className="input" /> +
+
+ + setDatasetSplit(e.target.value)} placeholder="e.g. train" className="input" /> +
+
+ + setDatasetFile(e.target.files[0])} accept=".json,.jsonl,.csv" className="input" style={{ padding: '6px' }} /> +
+
+
+ + {trainingMethod === 'grpo' && ( + +
+ GRPO requires at least one reward function. Select built-in functions or add custom ones. +
+ + {/* Built-in reward functions */} +
+ {BUILTIN_REWARDS.map(builtin => { + const isSelected = rewardFunctions.some(rf => rf.type === 'builtin' && rf.name === builtin.name) + const selectedRf = rewardFunctions.find(rf => rf.type === 'builtin' && rf.name === builtin.name) + return ( +
+ + {isSelected && builtin.params.length > 0 && ( +
+ {builtin.params.map(param => ( +
+ + { + setRewardFunctions(prev => prev.map(rf => + rf.type === 'builtin' && rf.name === builtin.name + ? { ...rf, params: { ...(rf.params || {}), [param.key]: e.target.value } } + : rf + )) + }} + /> +
+ ))} +
+ )} +
+ ) + })} +
+ + {/* Custom inline reward functions */} + {rewardFunctions.filter(rf => rf.type === 'inline').map((rf, idx) => ( +
+
+ + + {rf.name} + + +
+
+                    {rf.code}
+                  
+
+ ))} + + {/* Add custom reward button / form */} + {showAddCustomReward ? ( +
+
+ + setCustomRewardName(e.target.value)} + placeholder="e.g. my_custom_reward" style={{ maxWidth: '300px' }} /> +
+
+ +