Mirror of https://github.com/mudler/LocalAI.git (synced 2026-03-31 13:15:51 -04:00)
feat: add (experimental) fine-tuning support with TRL (#9088)
* feat: add fine-tuning endpoint
* feat(experimental): add fine-tuning endpoint and TRL support

  This changeset defines new gRPC signatures for fine-tuning backends and adds the TRL backend as the initial fine-tuning engine. The implementation also supports exporting to GGUF and automatically importing the result into LocalAI after fine-tuning.

* commit TRL backend, stop by killing process
* move fine-tune to generic features
* add evals, reorder menu
* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Committed by GitHub. Parent: f7e3aab4fc. Commit: d9c1db2b87.
141  .agents/debugging-backends.md  (new file)
@@ -0,0 +1,141 @@
# Debugging and Rebuilding Backends

When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.

## Architecture Overview

- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>`, which builds a Docker image, exports it, and installs it via `local-ai backends install`.
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.

Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.

## Diagnosing Failures

### 1. Check the logs

Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:

```
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
```

Common error patterns:

- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it, even as a no-op stub.
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
- **"failed to load backend"** — the gRPC process crashed or never started. Check the stderr lines for the traceback.

### 2. Test the Python environment directly

You can run the installed venv's Python to check imports without starting the full server:

```bash
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
```

If `pip` is missing from the venv, bootstrap it:

```bash
backends/<name>/venv/bin/python -m ensurepip
```

Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.

### 3. Check upstream dependency constraints

When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:

```
https://github.com/huggingface/trl/blob/main/requirements.txt
```

Pin minimum versions in the backend's requirements files to match upstream.
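To compare your pins against upstream quickly, a small parser over a requirements file can help. This is an illustrative sketch, not part of the repo; it only handles the plain `name>=x.y` form, not extras or environment markers:

```python
import re

def minimum_pins(requirements_text):
    """Extract `package>=version` minimum pins from requirements-file text."""
    pins = {}
    for line in requirements_text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        m = re.match(r"^([A-Za-z0-9_.\-]+)\s*>=\s*([0-9][0-9A-Za-z.\-]*)$", line)
        if m:
            pins[m.group(1)] = m.group(2)
    return pins
```

Run it over both the upstream file and your backend's requirements files and diff the resulting dicts by hand.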

## Common Fixes

### Missing gRPC methods

If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:

```python
def LoadModel(self, request, context):
    """No-op — actual loading happens elsewhere."""
    return backend_pb2.Result(success=True, message="OK")
```

The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.

### Dependency version conflicts

Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:

1. Identify the broken import in the logs
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
3. Check upstream requirements for version constraints
4. Update **all** requirements files in `backend/python/<name>/`:
   - `requirements.txt` — base deps (grpcio, protobuf)
   - `requirements-cpu.txt` — CPU-specific (includes the PyTorch CPU index)
   - `requirements-cublas12.txt` — CUDA 12
   - `requirements-cublas13.txt` — CUDA 13
5. Rebuild: `make backends/<name>`
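Step 2's import test can be extended into a scripted version check. A minimal sketch using only the standard library (the naive numeric comparison is enough for a sanity check, not a real resolver):

```python
from importlib import metadata

def parse_version(v):
    """Keep only the numeric components: "2.14.0" -> (2, 14, 0)."""
    return tuple(int(p) for p in v.split(".") if p.isdigit())

def meets_minimum(package, minimum):
    """True if `package` is installed at version >= `minimum`."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False  # not installed at all
    return parse_version(installed) >= parse_version(minimum)
```

Run it with the installed venv's Python (`backends/<name>/venv/bin/python`) so it inspects the environment the backend actually uses.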

### PyTorch index conflicts (uv resolver)

The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` in `install.sh`:

```bash
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
```

Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.

## Rebuilding

### Rebuild a single backend

```bash
make backends/<name>
```

This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).

**Important**: if you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:

```bash
GO_TAGS=auth make build
```

### Rebuild and restart

After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.

```bash
# Kill the existing process
kill <pid>

# Restart
./local-ai run --debug [your flags]
```

### Quick iteration (skip Docker rebuild)

For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:

```bash
# Edit the installed copy
vim backends/<name>/backend.py

# Restart LocalAI to respawn the gRPC process
```

This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.

## Verification

After fixing and rebuilding:

1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
2. Trigger the operation that failed (e.g. start a fine-tuning job)
3. Watch the GRPC stderr/stdout lines for the backend's model ID
4. Confirm no errors in the traceback
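Filtering the logs for one backend (steps 3–4) is easy to script. A small illustrative sketch that keeps only the lines tagged with a given model ID, assuming the `GRPC stderr id="..."` log format shown earlier:

```python
import re

def backend_log_lines(log_text, backend_id):
    """Return the GRPC stderr/stdout lines tagged with the given backend model ID."""
    pattern = re.compile(r'GRPC (stderr|stdout) id="%s"' % re.escape(backend_id))
    return [line for line in log_text.splitlines() if pattern.search(line)]
```

Pipe LocalAI's output to a file and run this over it (or just use `grep 'id="<model-id>"'`) to isolate a backend's traceback from the rest of the server noise.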

39  .github/workflows/backend.yml  (vendored)
@@ -118,6 +118,19 @@ jobs:
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: ''
  cuda-major-version: ""
  cuda-minor-version: ""
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-cpu-trl'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:24.04"
  skip-drivers: 'true'
  backend: "trl"
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: ''
  cuda-major-version: ""
  cuda-minor-version: ""
@@ -366,6 +379,19 @@ jobs:
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: 'cublas'
  cuda-major-version: "12"
  cuda-minor-version: "8"
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-nvidia-cuda-12-trl'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:24.04"
  skip-drivers: 'false'
  backend: "trl"
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: 'cublas'
  cuda-major-version: "12"
  cuda-minor-version: "8"
@@ -757,6 +783,19 @@ jobs:
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: 'cublas'
  cuda-major-version: "13"
  cuda-minor-version: "0"
  platforms: 'linux/amd64'
  tag-latest: 'auto'
  tag-suffix: '-gpu-nvidia-cuda-13-trl'
  runs-on: 'ubuntu-latest'
  base-image: "ubuntu:24.04"
  skip-drivers: 'false'
  backend: "trl"
  dockerfile: "./backend/Dockerfile.python"
  context: "./"
  ubuntu-version: '2404'
- build-type: 'l4t'
  cuda-major-version: "13"
  cuda-minor-version: "0"

@@ -12,6 +12,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |

## Quick Reference
8  Makefile
@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl

GOCMD=go
GOTEST=$(GOCMD) test
@@ -421,6 +421,7 @@ prepare-test-extra: protogen-python
	$(MAKE) -C backend/python/voxcpm
	$(MAKE) -C backend/python/whisperx
	$(MAKE) -C backend/python/ace-step
	$(MAKE) -C backend/python/trl

test-extra: prepare-test-extra
	$(MAKE) -C backend/python/transformers test
@@ -440,6 +441,7 @@ test-extra: prepare-test-extra
	$(MAKE) -C backend/python/voxcpm test
	$(MAKE) -C backend/python/whisperx test
	$(MAKE) -C backend/python/ace-step test
	$(MAKE) -C backend/python/trl test

DOCKER_IMAGE?=local-ai
IMAGE_TYPE?=core
@@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
BACKEND_TRL = trl|python|.|false|true

# Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))

# Pattern rule for docker-save targets
docker-save-%: backend-images
	docker save local-ai-backend:$* -o backend-images/$*.tar

docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl

########################################################
### Mock Backend for E2E Tests
@@ -39,6 +39,13 @@ service Backend {
  rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}

  rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}

  // Fine-tuning RPCs
  rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
  rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
  rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
  rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
  rpc ExportModel(ExportModelRequest) returns (Result) {}
}

// Define the empty request
@@ -528,3 +535,105 @@ message ModelMetadataResponse {
  string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
  ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
}

// Fine-tuning messages

message FineTuneRequest {
  // Model identification
  string model = 1; // HF model name or local path
  string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
  string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"

  // Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
  int32 adapter_rank = 10; // LoRA rank (r), default 16
  int32 adapter_alpha = 11; // scaling factor, default 16
  float adapter_dropout = 12; // default 0.0
  repeated string target_modules = 13; // layer names to adapt

  // Universal training hyperparameters
  float learning_rate = 20; // default 2e-4
  int32 num_epochs = 21; // default 3
  int32 batch_size = 22; // default 2
  int32 gradient_accumulation_steps = 23; // default 4
  int32 warmup_steps = 24; // default 5
  int32 max_steps = 25; // 0 = use epochs
  int32 save_steps = 26; // 0 = only save final
  float weight_decay = 27; // default 0.01
  bool gradient_checkpointing = 28;
  string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
  int32 seed = 30; // default 3407
  string mixed_precision = 31; // fp16, bf16, fp8, no

  // Dataset
  string dataset_source = 40; // HF dataset ID, local file/dir path
  string dataset_split = 41; // train, test, etc.

  // Output
  string output_dir = 50;
  string job_id = 51; // client-assigned or auto-generated

  // Resume training from a checkpoint
  string resume_from_checkpoint = 55; // path to checkpoint dir to resume from

  // Backend-specific AND method-specific extensibility
  map<string, string> extra_options = 60;
}

message FineTuneJobResult {
  string job_id = 1;
  bool success = 2;
  string message = 3;
}

message FineTuneProgressRequest {
  string job_id = 1;
}

message FineTuneProgressUpdate {
  string job_id = 1;
  int32 current_step = 2;
  int32 total_steps = 3;
  float current_epoch = 4;
  float total_epochs = 5;
  float loss = 6;
  float learning_rate = 7;
  float grad_norm = 8;
  float eval_loss = 9;
  float eta_seconds = 10;
  float progress_percent = 11;
  string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
  string message = 13;
  string checkpoint_path = 14; // set when a checkpoint is saved
  string sample_path = 15; // set when a sample is generated (video/image backends)
  map<string, float> extra_metrics = 16; // method-specific metrics
}

message FineTuneStopRequest {
  string job_id = 1;
  bool save_checkpoint = 2;
}

message ListCheckpointsRequest {
  string output_dir = 1;
}

message ListCheckpointsResponse {
  repeated CheckpointInfo checkpoints = 1;
}

message CheckpointInfo {
  string path = 1;
  int32 step = 2;
  float epoch = 3;
  float loss = 4;
  string created_at = 5;
}

message ExportModelRequest {
  string checkpoint_path = 1;
  string output_path = 2;
  string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
  string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
  string model = 5; // base model name (for merge operations)
  map<string, string> extra_options = 6;
}
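For reference, `progress_percent` and `eta_seconds` in `FineTuneProgressUpdate` are simple derived values; a minimal sketch of the arithmetic, mirroring the logic in the TRL backend's `ProgressCallback` (ETA is extrapolated from the average time per completed step):

```python
import time

def training_progress(global_step, total_steps, train_start, now=None):
    """Return (progress_percent, eta_seconds) for a running job.

    Assumes a constant per-step cost, averaged over the steps
    completed so far; the same approximation the TRL backend uses.
    """
    now = time.time() if now is None else now
    if total_steps <= 0 or global_step <= 0:
        return 0.0, 0.0
    progress = global_step / total_steps * 100
    elapsed = now - train_start
    eta = (total_steps - global_step) * (elapsed / global_step)
    return progress, eta
```

When `max_steps` is unset (`0 = use epochs`), `total_steps` is unknown until the trainer computes it, which is why both fields default to 0 early in a run.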

@@ -3030,3 +3030,54 @@
    uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
    mirrors:
      - localai/localai-backends:master-metal-darwin-arm64-voxtral
- &trl
  name: "trl"
  alias: "trl"
  license: apache-2.0
  description: |
    HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
    Works on CPU and GPU.
  urls:
    - https://github.com/huggingface/trl
  tags:
    - fine-tuning
    - LLM
    - CPU
    - GPU
    - CUDA
  capabilities:
    default: "cpu-trl"
    nvidia: "cuda12-trl"
    nvidia-cuda-12: "cuda12-trl"
    nvidia-cuda-13: "cuda13-trl"
## TRL backend images
- !!merge <<: *trl
  name: "cpu-trl"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
  mirrors:
    - localai/localai-backends:latest-cpu-trl
- !!merge <<: *trl
  name: "cpu-trl-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
  mirrors:
    - localai/localai-backends:master-cpu-trl
- !!merge <<: *trl
  name: "cuda12-trl"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
  mirrors:
    - localai/localai-backends:latest-cublas-cuda12-trl
- !!merge <<: *trl
  name: "cuda12-trl-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
  mirrors:
    - localai/localai-backends:master-cublas-cuda12-trl
- !!merge <<: *trl
  name: "cuda13-trl"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
  mirrors:
    - localai/localai-backends:latest-cublas-cuda13-trl
- !!merge <<: *trl
  name: "cuda13-trl-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
  mirrors:
    - localai/localai-backends:master-cublas-cuda13-trl
26  backend/python/trl/Makefile  (new file)
@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master

.PHONY: trl
trl:
	LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh

.PHONY: run
run: trl
	@echo "Running trl..."
	bash run.sh
	@echo "trl run."

.PHONY: test
test: trl
	@echo "Testing trl..."
	bash test.sh
	@echo "trl tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
860  backend/python/trl/backend.py  (new file)
@@ -0,0 +1,860 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TRL fine-tuning backend for LocalAI.
|
||||
|
||||
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ProgressCallback:
|
||||
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
|
||||
|
||||
def __init__(self, job_id, progress_queue, total_epochs):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = progress_queue
|
||||
self.total_epochs = total_epochs
|
||||
|
||||
def get_callback(self):
|
||||
from transformers import TrainerCallback
|
||||
|
||||
parent = self
|
||||
|
||||
class _Callback(TrainerCallback):
|
||||
def __init__(self):
|
||||
self._train_start_time = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self._train_start_time = time.time()
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if logs is None:
|
||||
return
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
eta = 0.0
|
||||
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
|
||||
elapsed = time.time() - self._train_start_time
|
||||
remaining_steps = total_steps - state.global_step
|
||||
if state.global_step > 0:
|
||||
eta = remaining_steps * (elapsed / state.global_step)
|
||||
|
||||
extra_metrics = {}
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
|
||||
extra_metrics[k] = float(v)
|
||||
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(logs.get('epoch', 0)),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
loss=float(logs.get('loss', 0)),
|
||||
learning_rate=float(logs.get('learning_rate', 0)),
|
||||
grad_norm=float(logs.get('grad_norm', 0)),
|
||||
eval_loss=float(logs.get('eval_loss', 0)),
|
||||
eta_seconds=float(eta),
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
extra_metrics=extra_metrics,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_prediction_step(self, args, state, control, **kwargs):
|
||||
"""Send periodic updates during evaluation so the UI doesn't freeze."""
|
||||
if not hasattr(self, '_eval_update_counter'):
|
||||
self._eval_update_counter = 0
|
||||
self._eval_update_counter += 1
|
||||
# Throttle: send an update every 10 prediction steps
|
||||
if self._eval_update_counter % 10 != 0:
|
||||
return
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(state.epoch or 0),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
message=f"Evaluating... (batch {self._eval_update_counter})",
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
"""Report eval results once evaluation is done."""
|
||||
# Reset prediction counter for next eval round
|
||||
self._eval_update_counter = 0
|
||||
|
||||
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||
|
||||
eval_loss = 0.0
|
||||
extra_metrics = {}
|
||||
if metrics:
|
||||
eval_loss = float(metrics.get('eval_loss', 0))
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
|
||||
extra_metrics[k] = float(v)
|
||||
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=total_steps,
|
||||
current_epoch=float(state.epoch or 0),
|
||||
total_epochs=float(parent.total_epochs),
|
||||
eval_loss=eval_loss,
|
||||
progress_percent=float(progress),
|
||||
status="training",
|
||||
message=f"Evaluation complete at step {state.global_step}",
|
||||
extra_metrics=extra_metrics,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
status="saving",
|
||||
message=f"Checkpoint saved at step {state.global_step}",
|
||||
checkpoint_path=checkpoint_path,
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=parent.job_id,
|
||||
current_step=state.global_step,
|
||||
total_steps=state.max_steps,
|
||||
progress_percent=100.0,
|
||||
status="completed",
|
||||
message="Training completed",
|
||||
)
|
||||
parent.progress_queue.put(update)
|
||||
|
||||
return _Callback()
|
||||
|
||||
|
||||
class ActiveJob:
|
||||
"""Represents an active fine-tuning job."""
|
||||
|
||||
def __init__(self, job_id):
|
||||
self.job_id = job_id
|
||||
self.progress_queue = queue.Queue()
|
||||
self.trainer = None
|
||||
self.thread = None
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.error = None
|
||||
self.completed = False
|
||||
self.stopped = False
|
||||
|
||||
|
||||
def _is_gated_repo_error(exc):
|
||||
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||
try:
|
||||
from huggingface_hub.utils import GatedRepoError
|
||||
if isinstance(exc, GatedRepoError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
msg = str(exc).lower()
|
||||
if "gated repo" in msg or "access to model" in msg:
|
||||
return True
|
||||
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||
if exc.response.status_code in (401, 403):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def __init__(self):
|
||||
self.active_job = None
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Accept LoadModel — actual model loading happens in StartFineTune."""
|
||||
return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def StartFineTune(self, request, context):
|
||||
if self.active_job is not None and not self.active_job.completed:
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id="",
|
||||
success=False,
|
||||
message="A fine-tuning job is already running",
|
||||
)
|
||||
|
||||
job_id = request.job_id if request.job_id else str(uuid.uuid4())
|
||||
job = ActiveJob(job_id)
|
||||
self.active_job = job
|
||||
|
||||
# Start training in background thread
|
||||
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
|
||||
job.thread = thread
|
||||
thread.start()
|
||||
|
||||
return backend_pb2.FineTuneJobResult(
|
||||
job_id=job_id,
|
||||
success=True,
|
||||
message="Fine-tuning job started",
|
||||
)
|
||||
|
||||
def _run_training(self, request, job):
|
||||
try:
|
||||
self._do_training(request, job)
|
||||
except Exception as e:
|
||||
if _is_gated_repo_error(e):
|
||||
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||
else:
|
||||
msg = f"Training failed: {e}"
|
||||
job.error = msg
|
||||
job.completed = True
|
||||
update = backend_pb2.FineTuneProgressUpdate(
|
||||
job_id=job.job_id,
|
||||
status="failed",
|
||||
message=msg,
|
||||
)
|
||||
job.progress_queue.put(update)
|
||||
# Send sentinel
|
||||
job.progress_queue.put(None)
|
||||
|
||||
    def _do_training(self, request, job):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from datasets import load_dataset, Dataset

        extra = dict(request.extra_options)
        training_method = request.training_method or "sft"
        training_type = request.training_type or "lora"

        # Send loading status
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
        ))

        # Determine device and dtype
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16

        # HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
        hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")

        # Load model
        model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
        if hf_token:
            model_kwargs["token"] = hf_token
        if extra.get("trust_remote_code", "false").lower() == "true":
            model_kwargs["trust_remote_code"] = True
        if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

        model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        job.model = model
        job.tokenizer = tokenizer

        # Apply LoRA if requested
        if training_type == "lora":
            from peft import LoraConfig, get_peft_model
            lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
            lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
            lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0

            target_modules = list(request.target_modules) if request.target_modules else None
            peft_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=target_modules or "all-linear",
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, peft_config)

        # Load dataset
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="loading_dataset", message="Loading dataset",
        ))

        dataset_split = request.dataset_split or "train"
        if os.path.exists(request.dataset_source):
            if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
                dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
            elif request.dataset_source.endswith('.csv'):
                dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
            else:
                dataset = load_dataset(request.dataset_source, split=dataset_split)
        else:
            dataset = load_dataset(request.dataset_source, split=dataset_split)

        # Eval dataset setup
        eval_dataset = None
        eval_strategy = extra.get("eval_strategy", "steps")
        eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))

        if eval_strategy != "no":
            eval_split = extra.get("eval_split")
            eval_dataset_source = extra.get("eval_dataset_source")
            if eval_split:
                # Load a specific split as eval dataset
                if os.path.exists(request.dataset_source):
                    if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
                        eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
                    elif request.dataset_source.endswith('.csv'):
                        eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
                    else:
                        eval_dataset = load_dataset(request.dataset_source, split=eval_split)
                else:
                    eval_dataset = load_dataset(request.dataset_source, split=eval_split)
            elif eval_dataset_source:
                # Load eval dataset from a separate source
                eval_dataset = load_dataset(eval_dataset_source, split="train")
            else:
                # Auto-split from training set
                eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
                split = dataset.train_test_split(test_size=eval_split_ratio)
                dataset = split["train"]
                eval_dataset = split["test"]

        if eval_strategy == "no":
            eval_dataset = None

        # Training config
        output_dir = request.output_dir or f"./output-{job.job_id}"
        num_epochs = request.num_epochs if request.num_epochs > 0 else 3
        batch_size = request.batch_size if request.batch_size > 0 else 2
        lr = request.learning_rate if request.learning_rate > 0 else 2e-4
        grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
        warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
        weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
        max_steps = request.max_steps if request.max_steps > 0 else -1
        save_steps = request.save_steps if request.save_steps > 0 else 500
        seed = request.seed if request.seed > 0 else 3407
        optimizer = request.optimizer or "adamw_torch"

        # Checkpoint save controls
        save_total_limit = int(extra.get("save_total_limit", "0")) or None  # 0 = unlimited
        save_strategy = extra.get("save_strategy", "steps")  # steps, epoch, no

        # CPU vs GPU training args (can be overridden via extra_options)
        use_cpu = not torch.cuda.is_available()
        common_train_kwargs = {}
        if use_cpu:
            common_train_kwargs["use_cpu"] = True
            common_train_kwargs["fp16"] = False
            common_train_kwargs["bf16"] = False
            common_train_kwargs["gradient_checkpointing"] = False
        else:
            common_train_kwargs["bf16"] = True
            common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing

        # Allow extra_options to override training kwargs
        for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
            if flag in extra:
                common_train_kwargs[flag] = extra[flag].lower() == "true"

        # Create progress callback
        progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)

        # Build save kwargs (shared across all methods)
        _save_kwargs = {}
        if save_strategy == "steps" and save_steps > 0:
            _save_kwargs["save_steps"] = save_steps
            _save_kwargs["save_strategy"] = "steps"
        elif save_strategy == "epoch":
            _save_kwargs["save_strategy"] = "epoch"
        elif save_strategy == "no":
            _save_kwargs["save_strategy"] = "no"
        else:
            _save_kwargs["save_steps"] = save_steps
            _save_kwargs["save_strategy"] = "steps"
        if save_total_limit:
            _save_kwargs["save_total_limit"] = save_total_limit

        # Eval kwargs
        _eval_kwargs = {}
        if eval_dataset is not None:
            _eval_kwargs["eval_strategy"] = eval_strategy
            _eval_kwargs["eval_steps"] = eval_steps

        # Common training arguments shared by all methods
        _common_args = dict(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=lr,
            gradient_accumulation_steps=grad_accum,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            max_steps=max_steps,
            seed=seed,
            optim=optimizer,
            logging_steps=1,
            report_to="none",
            **_save_kwargs,
            **common_train_kwargs,
            **_eval_kwargs,
        )

        # Select trainer based on training method
        if training_method == "sft":
            from trl import SFTTrainer, SFTConfig

            max_length = int(extra.get("max_seq_length", "512"))
            packing = extra.get("packing", "false").lower() == "true"

            training_args = SFTConfig(
                max_length=max_length,
                packing=packing,
                **_common_args,
            )

            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "dpo":
            from trl import DPOTrainer, DPOConfig

            beta = float(extra.get("beta", "0.1"))
            loss_type = extra.get("loss_type", "sigmoid")
            max_length = int(extra.get("max_length", "512"))

            training_args = DPOConfig(
                beta=beta,
                loss_type=loss_type,
                max_length=max_length,
                **_common_args,
            )

            trainer = DPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "grpo":
            from trl import GRPOTrainer, GRPOConfig

            num_generations = int(extra.get("num_generations", "4"))
            max_completion_length = int(extra.get("max_completion_length", "256"))

            training_args = GRPOConfig(
                num_generations=num_generations,
                max_completion_length=max_completion_length,
                **_common_args,
            )

            # GRPO requires reward functions passed via extra_options as a JSON list
            from reward_functions import build_reward_functions

            reward_funcs = []
            if extra.get("reward_funcs"):
                reward_funcs = build_reward_functions(extra["reward_funcs"])

            if not reward_funcs:
                raise ValueError(
                    "GRPO requires at least one reward function. "
                    "Specify reward_functions in the request or "
                    "reward_funcs in extra_options."
                )

            trainer = GRPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                processing_class=tokenizer,
                reward_funcs=reward_funcs,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "orpo":
            from trl import ORPOTrainer, ORPOConfig

            beta = float(extra.get("beta", "0.1"))
            max_length = int(extra.get("max_length", "512"))

            training_args = ORPOConfig(
                beta=beta,
                max_length=max_length,
                **_common_args,
            )

            trainer = ORPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "kto":
            from trl import KTOTrainer, KTOConfig

            beta = float(extra.get("beta", "0.1"))
            max_length = int(extra.get("max_length", "512"))

            training_args = KTOConfig(
                beta=beta,
                max_length=max_length,
                **_common_args,
            )

            trainer = KTOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "rloo":
            from trl import RLOOTrainer, RLOOConfig

            num_generations = int(extra.get("num_generations", "4"))
            max_completion_length = int(extra.get("max_completion_length", "256"))

            training_args = RLOOConfig(
                num_generations=num_generations,
                max_new_tokens=max_completion_length,
                **_common_args,
            )

            trainer = RLOOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "reward":
            from trl import RewardTrainer, RewardConfig

            max_length = int(extra.get("max_length", "512"))

            training_args = RewardConfig(
                max_length=max_length,
                **_common_args,
            )

            trainer = RewardTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        else:
            raise ValueError(f"Unsupported training method: {training_method}. "
                             "Supported: sft, dpo, grpo, orpo, kto, rloo, reward")

        job.trainer = trainer

        # Start training
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="training", message="Training started",
        ))

        resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
        trainer.train(resume_from_checkpoint=resume_ckpt)

        # Save final model
        trainer.save_model(output_dir)
        if tokenizer:
            tokenizer.save_pretrained(output_dir)

        job.completed = True
        # Sentinel to signal stream end
        job.progress_queue.put(None)

    def FineTuneProgress(self, request, context):
        if self.active_job is None or self.active_job.job_id != request.job_id:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Job {request.job_id} not found")
            return

        job = self.active_job
        while True:
            try:
                update = job.progress_queue.get(timeout=1.0)
                if update is None:
                    break
                yield update
                if update.status in ("completed", "failed", "stopped"):
                    break
            except queue.Empty:
                if job.completed or job.stopped:
                    break
                if not context.is_active():
                    break
                continue

    def StopFineTune(self, request, context):
        # Stopping is handled by killing the process from Go via ShutdownModel.
        return backend_pb2.Result(success=True, message="OK")

    def ListCheckpoints(self, request, context):
        output_dir = request.output_dir
        if not os.path.isdir(output_dir):
            return backend_pb2.ListCheckpointsResponse(checkpoints=[])

        checkpoints = []
        for entry in sorted(os.listdir(output_dir)):
            if entry.startswith("checkpoint-"):
                ckpt_path = os.path.join(output_dir, entry)
                if not os.path.isdir(ckpt_path):
                    continue
                step = 0
                try:
                    step = int(entry.split("-")[1])
                except (IndexError, ValueError):
                    pass

                # Try to read trainer_state.json for metadata
                loss = 0.0
                epoch = 0.0
                state_file = os.path.join(ckpt_path, "trainer_state.json")
                if os.path.exists(state_file):
                    try:
                        with open(state_file) as f:
                            state = json.load(f)
                        if state.get("log_history"):
                            last_log = state["log_history"][-1]
                            loss = last_log.get("loss", 0.0)
                            epoch = last_log.get("epoch", 0.0)
                    except Exception:
                        pass

                created_at = time.strftime(
                    "%Y-%m-%dT%H:%M:%SZ",
                    time.gmtime(os.path.getmtime(ckpt_path)),
                )

                checkpoints.append(backend_pb2.CheckpointInfo(
                    path=ckpt_path,
                    step=step,
                    epoch=float(epoch),
                    loss=float(loss),
                    created_at=created_at,
                ))

        return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)

    def ExportModel(self, request, context):
        export_format = request.export_format or "lora"
        output_path = request.output_path
        checkpoint_path = request.checkpoint_path

        # Extract HF token for gated model access
        extra = dict(request.extra_options) if request.extra_options else {}
        hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")

        if not checkpoint_path or not os.path.isdir(checkpoint_path):
            return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")

        os.makedirs(output_path, exist_ok=True)

        try:
            if export_format == "lora":
                # Just copy the adapter files
                import shutil
                for f in os.listdir(checkpoint_path):
                    src = os.path.join(checkpoint_path, f)
                    dst = os.path.join(output_path, f)
                    if os.path.isfile(src):
                        shutil.copy2(src, dst)

            elif export_format in ("merged_16bit", "merged_4bit"):
                import torch
                from transformers import AutoModelForCausalLM, AutoTokenizer
                from peft import PeftModel

                base_model_name = request.model
                if not base_model_name:
                    return backend_pb2.Result(success=False, message="Base model name required for merge export")

                dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
                model = PeftModel.from_pretrained(base_model, checkpoint_path)
                merged = model.merge_and_unload()
                merged.save_pretrained(output_path)

                tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
                tokenizer.save_pretrained(output_path)

            elif export_format == "gguf":
                import torch
                import subprocess
                import shutil
                from transformers import AutoModelForCausalLM, AutoTokenizer
                from peft import PeftModel

                base_model_name = request.model
                if not base_model_name:
                    return backend_pb2.Result(success=False, message="Base model name required for GGUF export")

                # Step 1: Merge LoRA into base model
                merge_dir = os.path.join(output_path, "_hf_merged")
                os.makedirs(merge_dir, exist_ok=True)

                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
                model = PeftModel.from_pretrained(base_model, checkpoint_path)
                merged = model.merge_and_unload()
                merged.save_pretrained(merge_dir)

                tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
                tokenizer.save_pretrained(merge_dir)

                # Ensure tokenizer.model (SentencePiece) is present in merge_dir.
                # Gemma models need this file for GGUF conversion to use the
                # SentencePiece path; without it, the script falls back to BPE
                # handling which fails on unrecognized pre-tokenizer hashes.
                sp_model_path = os.path.join(merge_dir, "tokenizer.model")
                if not os.path.exists(sp_model_path):
                    sp_copied = False
                    # Method 1: Load the slow tokenizer which keeps the SP model file
                    try:
                        slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
                        if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
                            import shutil as _shutil
                            _shutil.copy2(slow_tok.vocab_file, sp_model_path)
                            sp_copied = True
                            print("Copied tokenizer.model from slow tokenizer cache")
                    except Exception as e:
                        print(f"Slow tokenizer method failed: {e}")
                    # Method 2: Download from HF hub
                    if not sp_copied:
                        try:
                            from huggingface_hub import hf_hub_download
                            cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
                            import shutil as _shutil
                            _shutil.copy2(cached_sp, sp_model_path)
                            sp_copied = True
                            print("Copied tokenizer.model from HF hub")
                        except Exception as e:
                            print(f"HF hub download method failed: {e}")
                    if not sp_copied:
                        print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
                              "GGUF conversion may fail for SentencePiece models.")

                # Free GPU memory before conversion
                del merged, model, base_model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Step 2: Convert to GGUF using convert_hf_to_gguf.py
                quant = request.quantization_method or "auto"
                outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
                outtype = outtype_map.get(quant, "f16")

                gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
                gguf_path = os.path.join(output_path, gguf_filename)

                script_dir = os.path.dirname(os.path.abspath(__file__))
                convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
                if not os.path.exists(convert_script):
                    return backend_pb2.Result(success=False,
                                              message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")

                # Log merge_dir contents for debugging conversion issues
                merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
                print(f"Merge dir contents: {merge_files}", flush=True)

                env = os.environ.copy()
                env["NO_LOCAL_GGUF"] = "1"
                cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
                conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
                if conv_result.returncode != 0:
                    diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
                    return backend_pb2.Result(success=False,
                                              message=f"GGUF conversion failed: {diag}")

                # Clean up intermediate merged model
                shutil.rmtree(merge_dir, ignore_errors=True)
            else:
                return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")

        except Exception as e:
            if _is_gated_repo_error(e):
                return backend_pb2.Result(success=False,
                                          message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
                                                  "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
            return backend_pb2.Result(success=False, message=f"Export failed: {e}")

        return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")


def serve(address):
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)

    # Handle graceful shutdown
    def stop(signum, frame):
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGTERM, stop)
    signal.signal(signal.SIGINT, stop)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
    parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    args = parser.parse_args()
    serve(args.addr)
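For reference, a minimal sketch of how a caller might assemble the `reward_funcs` extra option consumed by the GRPO branch above. The spec fields (`type`, `name`, `code`, `params`) follow the format documented in `reward_functions.build_reward_functions`; the `ends_with_period` inline function is a hypothetical example, not part of the backend.

```python
import json

# Each spec selects a builtin reward function (see BUILTIN_REGISTRY in
# reward_functions.py) or supplies an inline function body. Param values are
# passed as strings; the backend converts them to int/float where possible.
reward_specs = [
    {"type": "builtin", "name": "format_reward"},
    {"type": "builtin", "name": "length_reward", "params": {"target_length": "300"}},
    {
        # Hypothetical inline reward: the "code" field is the body of a
        # function receiving (completions, **kwargs) -> list[float].
        "type": "inline",
        "name": "ends_with_period",
        "code": "return [1.0 if c.strip().endswith('.') else 0.0 for c in completions]",
    },
]

# The backend reads the JSON-encoded list from extra_options["reward_funcs"]
extra_options = {"reward_funcs": json.dumps(reward_specs)}
print(extra_options["reward_funcs"])
```

If `reward_funcs` is missing or empty, `StartFineTune` with `training_method="grpo"` fails with a `ValueError` asking for at least one reward function.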
37
backend/python/trl/install.sh
Normal file
@@ -0,0 +1,37 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements

# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
    curl -L --fail --retry 3 \
        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
        -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi

# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
    pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        pip install "gguf>=0.16.0"
    }
else
    uv pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        uv pip install "gguf>=0.16.0"
    }
fi
9
backend/python/trl/requirements-cpu.txt
Normal file
@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
9
backend/python/trl/requirements-cublas12.txt
Normal file
@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes
9
backend/python/trl/requirements-cublas13.txt
Normal file
@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes
3
backend/python/trl/requirements.txt
Normal file
@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi
236
backend/python/trl/reward_functions.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Built-in reward functions and inline function compiler for GRPO training.
|
||||
|
||||
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import math
|
||||
import string
|
||||
import functools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in reward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
|
||||
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
|
||||
return [1.0 if pattern.search(c) else 0.0 for c in completions]
|
||||
|
||||
|
||||
def reasoning_accuracy_reward(completions, **kwargs):
|
||||
"""Extracts <answer>...</answer> content and compares to the expected answer."""
|
||||
answers = kwargs.get("answer", [])
|
||||
if not answers:
|
||||
return [0.0] * len(completions)
|
||||
|
||||
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
|
||||
scores = []
|
||||
for i, c in enumerate(completions):
|
||||
expected = answers[i] if i < len(answers) else ""
|
||||
match = pattern.search(c)
|
||||
if match:
|
||||
extracted = match.group(1).strip()
|
||||
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
|
||||
else:
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
|
||||
def length_reward(completions, target_length=200, **kwargs):
|
||||
"""Score based on proximity to target_length. Returns [0, 1]."""
|
||||
scores = []
|
||||
for c in completions:
|
||||
length = len(c)
|
||||
if target_length <= 0:
|
||||
scores.append(0.0)
|
||||
else:
|
||||
diff = abs(length - target_length) / target_length
|
||||
scores.append(max(0.0, 1.0 - diff))
|
||||
return scores
|
||||
|
||||
|
||||
def xml_tag_reward(completions, **kwargs):
|
||||
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
|
||||
tags = ["think", "answer"]
|
||||
scores = []
|
||||
for c in completions:
|
||||
tag_score = 0.0
|
||||
for tag in tags:
|
||||
if f"<{tag}>" in c and f"</{tag}>" in c:
|
||||
tag_score += 0.5
|
||||
scores.append(min(tag_score, 1.0))
|
||||
return scores
|
||||
|
||||
|
||||
def no_repetition_reward(completions, n=4, **kwargs):
|
||||
"""Penalizes n-gram repetition. Returns [0, 1]."""
|
||||
scores = []
|
||||
for c in completions:
|
||||
words = c.split()
|
||||
if len(words) < n:
|
||||
scores.append(1.0)
|
||||
continue
|
||||
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
|
||||
unique = len(set(ngrams))
|
||||
total = len(ngrams)
|
||||
scores.append(unique / total if total > 0 else 1.0)
|
||||
return scores
|
||||
|
||||
|
||||
def code_execution_reward(completions, **kwargs):
|
||||
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
|
||||
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
|
||||
scores = []
|
||||
for c in completions:
|
||||
match = pattern.search(c)
|
||||
if not match:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
code = match.group(1)
|
||||
try:
|
||||
compile(code, "<inline>", "exec")
|
||||
scores.append(1.0)
|
||||
except SyntaxError:
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BUILTIN_REGISTRY = {
|
||||
"format_reward": format_reward,
|
||||
"reasoning_accuracy_reward": reasoning_accuracy_reward,
|
||||
"length_reward": length_reward,
|
||||
"xml_tag_reward": xml_tag_reward,
|
||||
"no_repetition_reward": no_repetition_reward,
|
||||
"code_execution_reward": code_execution_reward,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline function compiler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SAFE_BUILTINS = {
|
||||
"len": len, "int": int, "float": float, "str": str, "bool": bool,
|
||||
"list": list, "dict": dict, "tuple": tuple, "set": set,
|
||||
"range": range, "enumerate": enumerate, "zip": zip,
|
||||
"map": map, "filter": filter, "sorted": sorted,
|
||||
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
|
||||
"any": any, "all": all, "isinstance": isinstance, "type": type,
|
||||
"print": print, "True": True, "False": False, "None": None,
|
||||
"ValueError": ValueError, "TypeError": TypeError,
|
||||
"KeyError": KeyError, "IndexError": IndexError,
|
||||
}


def compile_inline_reward(name, code):
    """Compile user-provided code into a reward function.

    The code should be the body of a function that receives
    `completions` (list[str]) and `**kwargs`, and returns list[float].

    Available modules: re, math, json, string.
    """
    func_source = (
        f"def _user_reward_{name}(completions, **kwargs):\n"
        + "\n".join(f"    {line}" for line in code.splitlines())
    )

    restricted_globals = {
        "__builtins__": _SAFE_BUILTINS,
        "re": re,
        "math": math,
        "json": json,
        "string": string,
    }

    try:
        compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
    except SyntaxError as e:
        raise ValueError(f"Syntax error in inline reward function '{name}': {e}")

    exec(compiled, restricted_globals)
    func = restricted_globals[f"_user_reward_{name}"]

    # Validate with a quick smoke test
    try:
        result = func(["test"], answer=["test"])
        if not isinstance(result, list):
            raise ValueError(
                f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
            )
    except Exception as e:
        if "must return a list" in str(e):
            raise
        # Other errors during smoke test are acceptable (e.g. missing kwargs)
        pass

    return func
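The wrap-then-exec step above can be seen end to end in a small self-contained sketch (the reward body and the `_user_reward_demo` name are illustrative, not part of the backend):

```python
import re  # one of the modules inline rewards may use in the real backend

# A user-supplied function *body*; it gets wrapped into a def below.
code = 'return [float("answer" in c.lower()) for c in completions]'

func_source = (
    "def _user_reward_demo(completions, **kwargs):\n"
    + "\n".join(f"    {line}" for line in code.splitlines())
)

scope = {"__builtins__": {"float": float}, "re": re}
exec(compile(func_source, "<inline-reward-demo>", "exec"), scope)
reward = scope["_user_reward_demo"]

print(reward(["The answer is 42", "no idea"]))  # → [1.0, 0.0]
```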


# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------

def build_reward_functions(specs_json):
    """Parse a JSON list of reward function specs and return a list of callables.

    Each spec is a dict with:
    - type: "builtin" or "inline"
    - name: function name
    - code: (inline only) Python function body
    - params: (optional) dict of string params applied via functools.partial
    """
    if isinstance(specs_json, str):
        specs = json.loads(specs_json)
    else:
        specs = specs_json

    if not isinstance(specs, list):
        raise ValueError("reward_funcs must be a JSON array of reward function specs")

    reward_funcs = []
    for spec in specs:
        spec_type = spec.get("type", "builtin")
        name = spec.get("name", "")
        params = spec.get("params", {})

        if spec_type == "builtin":
            if name not in BUILTIN_REGISTRY:
                available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
                raise ValueError(
                    f"Unknown builtin reward function '{name}'. Available: {available}"
                )
            func = BUILTIN_REGISTRY[name]
            if params:
                # Convert string params to appropriate types
                typed_params = {}
                for k, v in params.items():
                    try:
                        typed_params[k] = int(v)
                    except (ValueError, TypeError):
                        try:
                            typed_params[k] = float(v)
                        except (ValueError, TypeError):
                            typed_params[k] = v
                func = functools.partial(func, **typed_params)
            reward_funcs.append(func)

        elif spec_type == "inline":
            code = spec.get("code", "")
            if not code.strip():
                raise ValueError(f"Inline reward function '{name}' has no code")
            func = compile_inline_reward(name, code)
            reward_funcs.append(func)

        else:
            raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")

    return reward_funcs
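The builtin path above, including string-to-number param coercion, can be exercised with a hypothetical registry entry (the `length_reward` function and `target_len` param are illustrative stand-ins for an entry of `BUILTIN_REGISTRY`):

```python
import functools
import json

# Hypothetical builtin standing in for a BUILTIN_REGISTRY entry.
def length_reward(completions, target_len=100, **kwargs):
    return [min(len(c), target_len) / target_len for c in completions]

registry = {"length_reward": length_reward}

# A spec list as the dispatcher would receive it; params arrive as strings.
specs = json.loads(
    '[{"type": "builtin", "name": "length_reward", "params": {"target_len": "50"}}]'
)

funcs = []
for spec in specs:
    func = registry[spec["name"]]
    # Coerce string params (int first, then float) before partial application,
    # mirroring the dispatcher above.
    typed = {}
    for k, v in spec.get("params", {}).items():
        try:
            typed[k] = int(v)
        except (ValueError, TypeError):
            try:
                typed[k] = float(v)
            except (ValueError, TypeError):
                typed[k] = v
    funcs.append(functools.partial(func, **typed))

print(funcs[0](["x" * 25]))  # → [0.5]
```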
backend/python/trl/run.sh (new file, 10 lines)
@@ -0,0 +1,10 @@
#!/bin/bash

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
backend/python/trl/test.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
Test script for the TRL fine-tuning gRPC backend.
"""
import unittest
import subprocess
import time

import grpc
import backend_pb2
import backend_pb2_grpc


class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service."""

    def setUp(self):
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(10)

    def tearDown(self):
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
                self.assertEqual(len(response.checkpoints), 0)
        except Exception as err:
            print(err)
            self.fail("ListCheckpoints service failed")
        finally:
            self.tearDown()


if __name__ == '__main__':
    unittest.main()
backend/python/trl/test.sh (new file, 11 lines)
@@ -0,0 +1,11 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests
@@ -121,6 +121,9 @@ type RunCMD struct {
	AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
	AgentHubURL               string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`

	// Fine-tuning
	EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`

	// Authentication
	AuthEnabled     bool   `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
	AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
@@ -326,6 +329,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
		opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
	}

	// Fine-tuning
	if r.EnableFineTuning {
		opts = append(opts, config.EnableFineTuning)
	}

	// Authentication
	authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
	if authEnabled {
@@ -97,6 +97,9 @@ type ApplicationConfig struct {
	// Agent Pool (LocalAGI integration)
	AgentPool AgentPoolConfig

	// Fine-tuning
	FineTuning FineTuningConfig

	// Authentication & Authorization
	Auth AuthConfig
}
@@ -142,6 +145,11 @@ type AgentPoolConfig struct {
	AgentHubURL string // default: "https://agenthub.localai.io"
}

// FineTuningConfig holds configuration for fine-tuning support.
type FineTuningConfig struct {
	Enabled bool
}

type AppOption func(*ApplicationConfig)

func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
@@ -733,6 +741,12 @@ func WithAgentHubURL(url string) AppOption {
	}
}

// Fine-tuning options

var EnableFineTuning = func(o *ApplicationConfig) {
	o.FineTuning.Enabled = true
}

// Auth options

func WithAuthEnabled(enabled bool) AppOption {
core/gallery/importers/local.go (new file, 205 lines)
@@ -0,0 +1,205 @@
package importers

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/xlog"
)

// ImportLocalPath scans a local directory for exported model files and produces
// a config.ModelConfig with the correct backend, model path, and options.
// Paths in the returned config are relative to modelsPath when possible so that
// the YAML config remains portable.
//
// Detection order:
//  1. GGUF files (*.gguf) — uses llama-cpp backend
//  2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
//  3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
	// Make paths relative to the models directory (parent of dirPath)
	// so config YAML stays portable.
	modelsDir := filepath.Dir(dirPath)
	relPath := func(absPath string) string {
		if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
			return rel
		}
		return absPath
	}

	// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
	ggufFile := findGGUF(dirPath)
	if ggufFile == "" {
		ggufSubdir := dirPath + "_gguf"
		ggufFile = findGGUF(ggufSubdir)
	}
	if ggufFile != "" {
		xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "llama-cpp",
			KnownUsecaseStrings: []string{"chat"},
			Options:             []string{"use_jinja:true"},
		}
		cfg.Model = relPath(ggufFile)
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.Description = buildDescription(dirPath, "GGUF")
		return cfg, nil
	}

	// 2. LoRA adapter: look for adapter_config.json
	adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
	if fileExists(adapterConfigPath) {
		xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
		baseModel := readBaseModel(dirPath)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "transformers",
			KnownUsecaseStrings: []string{"chat"},
		}
		cfg.Model = baseModel
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.LLMConfig.LoraAdapter = relPath(dirPath)
		cfg.Description = buildDescription(dirPath, "LoRA adapter")
		return cfg, nil
	}

	// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
	if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
		xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
		baseModel := readBaseModel(dirPath)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "transformers",
			KnownUsecaseStrings: []string{"chat"},
		}
		cfg.Model = baseModel
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.LLMConfig.LoraAdapter = relPath(dirPath)
		cfg.Description = buildDescription(dirPath, "LoRA adapter")
		return cfg, nil
	}

	// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
	if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
		xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "transformers",
			KnownUsecaseStrings: []string{"chat"},
		}
		cfg.Model = relPath(dirPath)
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.Description = buildDescription(dirPath, "merged model")
		return cfg, nil
	}

	return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
}
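The detection order above can be condensed into a few lines. A hypothetical Python stand-in (not the Go implementation; `detect_format` and its return labels are illustrative):

```python
import os
import tempfile

# Sketch of ImportLocalPath's detection order:
# GGUF file -> LoRA adapter_config.json -> merged model (config.json + safetensors).
def detect_format(path):
    names = os.listdir(path)
    if any(n.lower().endswith(".gguf") for n in names):
        return "gguf"    # served via the llama-cpp backend
    if "adapter_config.json" in names:
        return "lora"    # transformers backend with lora_adapter
    if "config.json" in names and any(n.endswith(".safetensors") for n in names):
        return "merged"  # transformers backend
    return None

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "adapter_config.json"), "w").close()
    print(detect_format(d))  # → lora
```

Order matters: a directory holding both a GGUF export and adapter files resolves to GGUF first, matching the Go code.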

// findGGUF returns the path to the first .gguf file found in dir, or "".
func findGGUF(dir string) string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return ""
	}
	for _, e := range entries {
		if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") {
			return filepath.Join(dir, e.Name())
		}
	}
	return ""
}

// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
func readBaseModel(dirPath string) string {
	// Try adapter_config.json → base_model_name_or_path (TRL writes this)
	if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
		var ac map[string]any
		if json.Unmarshal(data, &ac) == nil {
			if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
				return bm
			}
		}
	}

	// Try export_metadata.json → base_model (Unsloth writes this)
	if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
		var meta map[string]any
		if json.Unmarshal(data, &meta) == nil {
			if bm, ok := meta["base_model"].(string); ok && bm != "" {
				return bm
			}
		}
	}

	return ""
}
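The same fallback order, written as a hypothetical Python sketch (the `read_base_model` helper is illustrative, not the Go function):

```python
import json
import os
import tempfile

# Prefer adapter_config.json's base_model_name_or_path (written by TRL/PEFT),
# then export_metadata.json's base_model (written by Unsloth).
def read_base_model(d):
    for fname, key in [
        ("adapter_config.json", "base_model_name_or_path"),
        ("export_metadata.json", "base_model"),
    ]:
        try:
            with open(os.path.join(d, fname)) as f:
                value = json.load(f).get(key, "")
        except (OSError, json.JSONDecodeError):
            continue
        if value:
            return value
    return ""

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "export_metadata.json"), "w") as f:
        json.dump({"base_model": "unsloth/tinyllama-bnb-4bit"}, f)
    print(read_base_model(d))  # → unsloth/tinyllama-bnb-4bit
```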

// buildDescription creates a human-readable description using available metadata.
func buildDescription(dirPath, formatLabel string) string {
	base := ""

	// Try adapter_config.json
	if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
		var ac map[string]any
		if json.Unmarshal(data, &ac) == nil {
			if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
				base = bm
			}
		}
	}

	// Try export_metadata.json
	if base == "" {
		if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
			var meta map[string]any
			if json.Unmarshal(data, &meta) == nil {
				if bm, ok := meta["base_model"].(string); ok && bm != "" {
					base = bm
				}
			}
		}
	}

	if base != "" {
		return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
	}
	return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}

func fileExists(path string) bool {
	info, err := os.Stat(path)
	return err == nil && !info.IsDir()
}

func hasFileWithSuffix(dir, suffix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	for _, e := range entries {
		if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
			return true
		}
	}
	return false
}

func hasFileWithPrefix(dir, prefix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	for _, e := range entries {
		if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) {
			return true
		}
	}
	return false
}
core/gallery/importers/local_test.go (new file, 148 lines)
@@ -0,0 +1,148 @@
package importers_test

import (
	"encoding/json"
	"os"
	"path/filepath"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/gallery/importers"
)

var _ = Describe("ImportLocalPath", func() {
	var tmpDir string

	BeforeEach(func() {
		var err error
		tmpDir, err = os.MkdirTemp("", "importers-local-test")
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		os.RemoveAll(tmpDir)
	})

	Context("GGUF detection", func() {
		It("detects a GGUF file in the directory", func() {
			modelDir := filepath.Join(tmpDir, "my-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
			Expect(cfg.Model).To(ContainSubstring(".gguf"))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
			Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
			Expect(cfg.Options).To(ContainElement("use_jinja:true"))
		})

		It("detects GGUF in _gguf subdirectory", func() {
			modelDir := filepath.Join(tmpDir, "my-model")
			ggufDir := modelDir + "_gguf"
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
		})
	})

	Context("LoRA adapter detection", func() {
		It("detects LoRA adapter via adapter_config.json", func() {
			modelDir := filepath.Join(tmpDir, "lora-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			adapterConfig := map[string]any{
				"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
				"peft_type":               "LORA",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
			Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model"))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})

		It("reads base model from export_metadata.json as fallback", func() {
			modelDir := filepath.Join(tmpDir, "lora-unsloth")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			adapterConfig := map[string]any{"peft_type": "LORA"}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
			data, _ = json.Marshal(metadata)
			Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
		})
	})

	Context("Merged model detection", func() {
		It("detects merged model with safetensors + config.json", func() {
			modelDir := filepath.Join(tmpDir, "merged")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "merged")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal("merged"))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})

		It("detects merged model with pytorch_model files", func() {
			modelDir := filepath.Join(tmpDir, "merged-pt")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal("merged-pt"))
		})
	})

	Context("fallback", func() {
		It("returns error for empty directory", func() {
			modelDir := filepath.Join(tmpDir, "empty")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			_, err := importers.ImportLocalPath(modelDir, "empty")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("could not detect model format"))
		})
	})

	Context("description", func() {
		It("includes base model name in description", func() {
			modelDir := filepath.Join(tmpDir, "desc-test")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			adapterConfig := map[string]any{
				"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
			Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
		})
	})
})
@@ -302,6 +302,17 @@ func API(application *application.Application) (*echo.Echo, error) {
	mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
	routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
	// Fine-tuning routes
	if application.ApplicationConfig().FineTuning.Enabled {
		fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
		ftService := services.NewFineTuneService(
			application.ApplicationConfig(),
			application.ModelLoader(),
			application.ModelConfigLoader(),
		)
		routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
	}

	routes.RegisterOpenAIRoutes(e, requestExtractor, application)
	routes.RegisterAnthropicRoutes(e, requestExtractor, application)
	routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
@@ -85,6 +85,18 @@ var RouteFeatureRegistry = []RouteFeature{
	{"POST", "/stores/delete", FeatureStores},
	{"POST", "/stores/get", FeatureStores},
	{"POST", "/stores/find", FeatureStores},

	// Fine-tuning
	{"POST", "/api/fine-tuning/jobs", FeatureFineTuning},
	{"GET", "/api/fine-tuning/jobs", FeatureFineTuning},
	{"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
	{"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning},
	{"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
	{"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning},
	{"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning},
	{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
	{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
	{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
}
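The route table above maps directly to HTTP calls. A minimal client sketch using only the standard library (host, port, and payload values are illustrative; `model` and `dataset_source` are the two fields the job-creation handler validates as required):

```python
import json
import urllib.request

# Build (but do not send) a request for the job-creation route above.
payload = {"model": "my-base-model", "dataset_source": "/datasets/train.jsonl"}
req = urllib.request.Request(
    "http://localhost:8080/api/fine-tuning/jobs",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# urllib.request.urlopen(req) would submit the job against a running server.
print(req.get_method(), req.full_url)
```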

// FeatureMeta describes a feature for the admin API/UI.
@@ -104,6 +116,13 @@ func AgentFeatureMetas() []FeatureMeta {
	}
}

// GeneralFeatureMetas returns metadata for general features.
func GeneralFeatureMetas() []FeatureMeta {
	return []FeatureMeta{
		{FeatureFineTuning, "Fine-Tuning", false},
	}
}

// APIFeatureMetas returns metadata for API endpoint features.
func APIFeatureMetas() []FeatureMeta {
	return []FeatureMeta{
@@ -32,6 +32,9 @@ const (
	FeatureCollections = "collections"
	FeatureMCPJobs     = "mcp_jobs"

	// General features (default OFF for new users)
	FeatureFineTuning = "fine_tuning"

	// API features (default ON for new users)
	FeatureChat   = "chat"
	FeatureImages = "images"
@@ -52,6 +55,9 @@ const (
// AgentFeatures lists agent-related features (default OFF).
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}

// GeneralFeatures lists general features (default OFF).
var GeneralFeatures = []string{FeatureFineTuning}

// APIFeatures lists API endpoint features (default ON).
var APIFeatures = []string{
	FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
@@ -60,7 +66,7 @@ var APIFeatures = []string{
}

// AllFeatures lists all known features (used by UI and validation).
-var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...)
+var AllFeatures = append(append(append([]string{}, AgentFeatures...), GeneralFeatures...), APIFeatures...)

// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map.
var defaultOnFeatures = func() map[string]bool {
@@ -80,13 +80,14 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		defer src.Close()
-		if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil {
+		key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src)
+		if err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
-		return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
+		return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key})
	}
}
core/http/endpoints/localai/finetune.go (new file, 362 lines)
@@ -0,0 +1,362 @@
package localai

import (
	"archive/tar"
	"compress/gzip"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
)

// StartFineTuneJobEndpoint starts a new fine-tuning job.
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)

		var req schema.FineTuneJobRequest
		if err := c.Bind(&req); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "Invalid request: " + err.Error(),
			})
		}

		if req.Model == "" {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "model is required",
			})
		}
		if req.DatasetSource == "" {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "dataset_source is required",
			})
		}

		resp, err := ftService.StartJob(c.Request().Context(), userID, req)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusCreated, resp)
	}
}

// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobs := ftService.ListJobs(userID)
		if jobs == nil {
			jobs = []*schema.FineTuneJob{}
		}
		return c.JSON(http.StatusOK, jobs)
	}
}

// GetFineTuneJobEndpoint gets a specific fine-tuning job.
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		job, err := ftService.GetJob(userID, jobID)
		if err != nil {
			return c.JSON(http.StatusNotFound, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusOK, job)
	}
}

// StopFineTuneJobEndpoint stops a running fine-tuning job.
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		// Check for save_checkpoint query param
		saveCheckpoint := c.QueryParam("save_checkpoint") == "true"

		err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
		if err != nil {
			return c.JSON(http.StatusNotFound, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusOK, map[string]string{
			"status":  "stopped",
			"message": "Fine-tuning job stopped",
		})
	}
}

// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data.
func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		err := ftService.DeleteJob(userID, jobID)
		if err != nil {
			status := http.StatusInternalServerError
			if strings.Contains(err.Error(), "not found") {
				status = http.StatusNotFound
			} else if strings.Contains(err.Error(), "cannot delete") {
				status = http.StatusConflict
			}
			return c.JSON(status, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusOK, map[string]string{
			"status":  "deleted",
			"message": "Fine-tuning job deleted",
		})
	}
}

// FineTuneProgressEndpoint streams progress updates via SSE.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		// Set SSE headers
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().WriteHeader(http.StatusOK)

		err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
			data, err := json.Marshal(event)
			if err != nil {
				return
			}
			fmt.Fprintf(c.Response(), "data: %s\n\n", data)
			c.Response().Flush()
		})
		if err != nil {
			// If headers already sent, we can't send a JSON error
			fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
			c.Response().Flush()
		}

		return nil
	}
}
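On the client side, the stream this handler emits is plain SSE: one `data: <json>` line per event, with events separated by blank lines. A small parsing sketch over a canned stream (the event fields shown are illustrative):

```python
import json

# A canned stream in the shape FineTuneProgressEndpoint writes.
stream = (
    'data: {"status": "running", "step": 10}\n\n'
    'data: {"status": "completed"}\n\n'
)

# Split on blank lines, keep data frames, parse the JSON payloads.
events = [
    json.loads(frame[len("data: "):])
    for frame in stream.split("\n\n")
    if frame.startswith("data: ")
]
print(events[-1]["status"])  # → completed
```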

// ListCheckpointsEndpoint lists checkpoints for a job.
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
		if err != nil {
			return c.JSON(http.StatusNotFound, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusOK, map[string]any{
			"checkpoints": checkpoints,
		})
	}
}
|
||||
|
||||
// ExportModelEndpoint exports a model from a checkpoint.
|
||||
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
var req schema.ExportRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusAccepted, map[string]string{
|
||||
"status": "exporting",
|
||||
"message": "Export started for model '" + modelName + "'",
|
||||
"model_name": modelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive.
|
||||
func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Content-Type", "application/gzip")
|
||||
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName))
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
gw := gzip.NewWriter(c.Response())
|
||||
defer gw.Close()
|
||||
|
||||
tw := tar.NewWriter(gw)
|
||||
defer tw.Close()
|
||||
|
||||
err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(modelDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Name = filepath.Join(modelName, relPath)
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(tw, f)
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Headers already sent, can't return JSON error
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
|
||||
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to list backends: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
type backendInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
var result []backendInfo
|
||||
for _, b := range backends {
|
||||
if !b.Installed {
|
||||
continue
|
||||
}
|
||||
hasTag := false
|
||||
for _, t := range b.Tags {
|
||||
if strings.EqualFold(t, "fine-tuning") {
|
||||
hasTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTag {
|
||||
continue
|
||||
}
|
||||
name := b.Name
|
||||
if b.Alias != "" {
|
||||
name = b.Alias
|
||||
}
|
||||
result = append(result, backendInfo{
|
||||
Name: name,
|
||||
Description: b.Description,
|
||||
Tags: b.Tags,
|
||||
})
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = []backendInfo{}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// UploadDatasetEndpoint handles dataset file upload.
|
||||
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "file is required",
|
||||
})
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to open file",
|
||||
})
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
data, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to read file",
|
||||
})
|
||||
}
|
||||
|
||||
path, err := ftService.UploadDataset(file.Filename, data)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"path": path,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -208,6 +208,32 @@
  overflow: hidden;
}

.sidebar-section-toggle {
  display: flex;
  align-items: center;
  justify-content: space-between;
  width: 100%;
  background: none;
  border: none;
  cursor: pointer;
  font-family: inherit;
  transition: color var(--duration-fast);
}

.sidebar-section-toggle:hover {
  color: var(--color-text-secondary);
}

.sidebar-section-chevron {
  font-size: 0.5rem;
  transition: transform var(--duration-fast);
  flex-shrink: 0;
}

.sidebar-section-toggle.open .sidebar-section-chevron {
  transform: rotate(90deg);
}

.nav-item {
  display: flex;
  align-items: center;
@@ -392,6 +418,10 @@
  display: none;
}

.sidebar.collapsed .sidebar-section-chevron {
  display: none;
}

.sidebar.collapsed .nav-item {
  justify-content: center;
  padding: 8px 0;
@@ -612,14 +642,6 @@
.spinner-md .spinner-ring { width: 24px; height: 24px; }
.spinner-lg .spinner-ring { width: 40px; height: 40px; }

.spinner-logo {
  animation: pulse 1.2s ease-in-out infinite;
  object-fit: contain;
}
.spinner-sm .spinner-logo { width: 16px; height: 16px; }
.spinner-md .spinner-logo { width: 24px; height: 24px; }
.spinner-lg .spinner-logo { width: 40px; height: 40px; }

/* Model selector */
.model-selector {
  background: var(--color-bg-tertiary);
@@ -2646,6 +2668,43 @@
  font-size: 0.625rem;
}

/* Studio tabs */
.studio-tabs {
  display: flex;
  gap: 0;
  border-bottom: 1px solid var(--color-border-subtle);
  padding: 0 var(--spacing-xl);
  background: var(--color-bg-primary);
  position: sticky;
  top: 0;
  z-index: 10;
}

.studio-tab {
  display: flex;
  align-items: center;
  gap: 6px;
  background: none;
  border: none;
  padding: var(--spacing-sm) var(--spacing-md);
  font-size: 0.8125rem;
  font-family: inherit;
  color: var(--color-text-secondary);
  cursor: pointer;
  border-bottom: 2px solid transparent;
  transition: color var(--duration-fast), border-color var(--duration-fast);
}

.studio-tab:hover {
  color: var(--color-text-primary);
}

.studio-tab-active {
  color: var(--color-primary);
  border-bottom-color: var(--color-primary);
  font-weight: 500;
}

/* Two-column layout for media generation pages */
.media-layout {
  display: grid;
@@ -1,22 +1,8 @@
import { useState } from 'react'
import { apiUrl } from '../utils/basePath'

export default function LoadingSpinner({ size = 'md', className = '' }) {
  const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md'
  const [imgFailed, setImgFailed] = useState(false)

  return (
    <div className={`spinner ${sizeClass} ${className}`}>
      {imgFailed ? (
        <div className="spinner-ring" />
      ) : (
        <img
          src={apiUrl('/static/logo.png')}
          alt=""
          className="spinner-logo"
          onError={() => setImgFailed(true)}
        />
      )}
      <div className="spinner-ring" />
    </div>
  )
}
@@ -1,37 +1,57 @@
import { useState, useEffect } from 'react'
import { NavLink, useNavigate } from 'react-router-dom'
import { NavLink, useNavigate, useLocation } from 'react-router-dom'
import ThemeToggle from './ThemeToggle'
import { useAuth } from '../context/AuthContext'
import { apiUrl } from '../utils/basePath'

const COLLAPSED_KEY = 'localai_sidebar_collapsed'
const SECTIONS_KEY = 'localai_sidebar_sections'

const mainItems = [
const topItems = [
  { path: '/app', icon: 'fas fa-home', label: 'Home' },
  { path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true },
  { path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' },
  { path: '/app/image', icon: 'fas fa-image', label: 'Images' },
  { path: '/app/video', icon: 'fas fa-video', label: 'Video' },
  { path: '/app/tts', icon: 'fas fa-music', label: 'TTS' },
  { path: '/app/sound', icon: 'fas fa-volume-high', label: 'Sound' },
  { path: '/app/studio', icon: 'fas fa-palette', label: 'Studio' },
  { path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' },
  { path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
]

const agentItems = [
  { path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
  { path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
  { path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
  { path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
]

const systemItems = [
  { path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
  { path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
  { path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
  { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
  { path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
  { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
const sections = [
  {
    id: 'tools',
    title: 'Tools',
    items: [
      { path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' },
    ],
  },
  {
    id: 'agents',
    title: 'Agents',
    featureMap: {
      '/app/agents': 'agents',
      '/app/skills': 'skills',
      '/app/collections': 'collections',
      '/app/agent-jobs': 'mcp_jobs',
    },
    items: [
      { path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
      { path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
      { path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
      { path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
    ],
  },
  {
    id: 'system',
    title: 'System',
    items: [
      { path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
      { path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
      { path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
      { path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
      { path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
      { path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
      { path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
    ],
  },
]

function NavItem({ item, onClose, collapsed }) {
@@ -51,18 +71,47 @@ function NavItem({ item, onClose, collapsed }) {
  )
}

function loadSectionState() {
  try {
    const stored = localStorage.getItem(SECTIONS_KEY)
    return stored ? JSON.parse(stored) : {}
  } catch (_) {
    return {}
  }
}

function saveSectionState(state) {
  try { localStorage.setItem(SECTIONS_KEY, JSON.stringify(state)) } catch (_) { /* ignore */ }
}

export default function Sidebar({ isOpen, onClose }) {
  const [features, setFeatures] = useState({})
  const [collapsed, setCollapsed] = useState(() => {
    try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false }
  })
  const [openSections, setOpenSections] = useState(loadSectionState)
  const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth()
  const navigate = useNavigate()
  const location = useLocation()

  useEffect(() => {
    fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {})
  }, [])

  // Auto-expand section containing the active route
  useEffect(() => {
    for (const section of sections) {
      const match = section.items.some(item => location.pathname.startsWith(item.path))
      if (match && !openSections[section.id]) {
        setOpenSections(prev => {
          const next = { ...prev, [section.id]: true }
          saveSectionState(next)
          return next
        })
      }
    }
  }, [location.pathname])

  const toggleCollapse = () => {
    setCollapsed(prev => {
      const next = !prev
@@ -72,17 +121,34 @@ export default function Sidebar({ isOpen, onClose }) {
    })
  }

  const visibleMainItems = mainItems.filter(item => {
    if (item.adminOnly && !isAdmin) return false
    if (item.authOnly && !authEnabled) return false
    return true
  })
  const toggleSection = (id) => {
    setOpenSections(prev => {
      const next = { ...prev, [id]: !prev[id] }
      saveSectionState(next)
      return next
    })
  }

  const visibleSystemItems = systemItems.filter(item => {
  const filterItem = (item) => {
    if (item.adminOnly && !isAdmin) return false
    if (item.authOnly && !authEnabled) return false
    if (item.feature && features[item.feature] === false) return false
    if (item.feature && !hasFeature(item.feature)) return false
    return true
  })
  }

  const visibleTopItems = topItems.filter(filterItem)

  const getVisibleSectionItems = (section) => {
    return section.items.filter(item => {
      if (!filterItem(item)) return false
      if (section.featureMap) {
        const featureName = section.featureMap[item.path]
        return featureName ? hasFeature(featureName) : isAdmin
      }
      return true
    })
  }

  return (
    <>
@@ -104,57 +170,57 @@

      {/* Navigation */}
      <nav className="sidebar-nav">
        {/* Main section */}
        {/* Top-level items */}
        <div className="sidebar-section">
          {visibleMainItems.map(item => (
          {visibleTopItems.map(item => (
            <NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
          ))}
        </div>

        {/* Agents section (per-feature permissions) */}
        {features.agents !== false && (() => {
          const featureMap = {
            '/app/agents': 'agents',
            '/app/skills': 'skills',
            '/app/collections': 'collections',
            '/app/agent-jobs': 'mcp_jobs',
          }
          const visibleAgentItems = agentItems.filter(item => {
            if (item.feature && features[item.feature] === false) return false
            const featureName = featureMap[item.path]
            return featureName ? hasFeature(featureName) : isAdmin
          })
          if (visibleAgentItems.length === 0) return null
        {/* Collapsible sections */}
        {sections.map(section => {
          // For agents section, check global feature flag
          if (section.id === 'agents' && features.agents === false) return null

          const visibleItems = getVisibleSectionItems(section)
          if (visibleItems.length === 0) return null

          const isSectionOpen = openSections[section.id]
          const showItems = isSectionOpen || collapsed

          return (
            <div className="sidebar-section">
              <div className="sidebar-section-title">Agents</div>
              {visibleAgentItems.map(item => (
                <NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
              ))}
            <div key={section.id} className="sidebar-section">
              <button
                className={`sidebar-section-title sidebar-section-toggle ${isSectionOpen ? 'open' : ''}`}
                onClick={() => toggleSection(section.id)}
                title={collapsed ? section.title : undefined}
              >
                <span>{section.title}</span>
                <i className="fas fa-chevron-right sidebar-section-chevron" />
              </button>
              {showItems && (
                <div className="sidebar-section-items">
                  {section.id === 'system' && (
                    <a
                      href={apiUrl('/swagger/index.html')}
                      target="_blank"
                      rel="noopener noreferrer"
                      className="nav-item"
                      title={collapsed ? 'API' : undefined}
                    >
                      <i className="fas fa-code nav-icon" />
                      <span className="nav-label">API</span>
                      <i className="fas fa-external-link-alt nav-external" />
                    </a>
                  )}
                  {visibleItems.map(item => (
                    <NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
                  ))}
                </div>
              )}
            </div>
          )
        })()}

        {/* System section */}
        <div className="sidebar-section">
          {visibleSystemItems.length > 0 && (
            <div className="sidebar-section-title">System</div>
          )}
          <a
            href={apiUrl('/swagger/index.html')}
            target="_blank"
            rel="noopener noreferrer"
            className="nav-item"
            title={collapsed ? 'API' : undefined}
          >
            <i className="fas fa-code nav-icon" />
            <span className="nav-label">API</span>
            <i className="fas fa-external-link-alt nav-external" />
          </a>
          {visibleSystemItems.map(item => (
            <NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
          ))}
        </div>
        })}
      </nav>

      {/* Footer */}
1525	core/http/react-ui/src/pages/FineTune.jsx	Normal file
File diff suppressed because it is too large
48	core/http/react-ui/src/pages/Studio.jsx	Normal file
@@ -0,0 +1,48 @@
import { useSearchParams } from 'react-router-dom'
import ImageGen from './ImageGen'
import VideoGen from './VideoGen'
import TTS from './TTS'
import Sound from './Sound'

const TABS = [
  { key: 'images', label: 'Images', icon: 'fas fa-image' },
  { key: 'video', label: 'Video', icon: 'fas fa-video' },
  { key: 'tts', label: 'TTS', icon: 'fas fa-headphones' },
  { key: 'sound', label: 'Sound', icon: 'fas fa-music' },
]

const TAB_COMPONENTS = {
  images: ImageGen,
  video: VideoGen,
  tts: TTS,
  sound: Sound,
}

export default function Studio() {
  const [searchParams, setSearchParams] = useSearchParams()
  const activeTab = searchParams.get('tab') || 'images'

  const setTab = (key) => {
    setSearchParams({ tab: key }, { replace: true })
  }

  const ActiveComponent = TAB_COMPONENTS[activeTab] || ImageGen

  return (
    <div>
      <div className="studio-tabs">
        {TABS.map(tab => (
          <button
            key={tab.key}
            className={`studio-tab${activeTab === tab.key ? ' studio-tab-active' : ''}`}
            onClick={() => setTab(tab.key)}
          >
            <i className={tab.icon} />
            <span>{tab.label}</span>
          </button>
        ))}
      </div>
      <ActiveComponent />
    </div>
  )
}
@@ -45,9 +45,11 @@ function PermissionSummary({ user, onClick }) {
  const perms = user.permissions || {}
  const apiFeatures = ['chat', 'images', 'audio_speech', 'audio_transcription', 'vad', 'detection', 'video', 'embeddings', 'sound']
  const agentFeatures = ['agents', 'skills', 'collections', 'mcp_jobs']
  const generalFeatures = ['fine_tuning']

  const apiOn = apiFeatures.filter(f => perms[f] !== false && (perms[f] === true || perms[f] === undefined)).length
  const agentOn = agentFeatures.filter(f => perms[f]).length
  const generalOn = generalFeatures.filter(f => perms[f]).length

  const modelRestricted = user.allowed_models?.enabled

@@ -58,7 +60,7 @@ function PermissionSummary({ user, onClick }) {
      title="Edit permissions"
    >
      <i className="fas fa-shield-halved" />
      {apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent
      {apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent, {generalOn}/{generalFeatures.length} Features
      {modelRestricted && ' | Models restricted'}
    </button>
  )
@@ -71,6 +73,7 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,

  const apiFeatures = featureMeta?.api_features || []
  const agentFeatures = featureMeta?.agent_features || []
  const generalFeatures = featureMeta?.general_features || []

  useEffect(() => {
    const handleKeyDown = (e) => {
@@ -189,6 +192,33 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,
        </div>
      </div>

      {/* General Features */}
      {generalFeatures.length > 0 && (
        <div className="perm-section">
          <div className="perm-section-header">
            <strong className="perm-section-title">
              <i className="fas fa-sliders" />
              Features
            </strong>
            <div className="action-group">
              <button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, true)}>All</button>
              <button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, false)}>None</button>
            </div>
          </div>
          <div className="perm-grid">
            {generalFeatures.map(f => (
              <button
                key={f.key}
                className={`btn btn-sm ${permissions[f.key] ? 'btn-primary' : 'btn-secondary'} perm-btn-feature`}
                onClick={() => toggleFeature(f.key)}
              >
                {f.label}
              </button>
            ))}
          </div>
        </div>
      )}

      {/* Model Access */}
      <div className="perm-section">
        <div className="perm-section-header">
@@ -510,6 +540,9 @@ export default function Users() {
        { key: 'collections', label: 'Collections', default: false },
        { key: 'mcp_jobs', label: 'MCP CI Jobs', default: false },
      ],
      general_features: [
        { key: 'fine_tuning', label: 'Fine-Tuning', default: false },
      ],
    })
  }
}, [])
@@ -31,6 +31,8 @@ import ImportModel from './pages/ImportModel'
import BackendLogs from './pages/BackendLogs'
import Explorer from './pages/Explorer'
import Login from './pages/Login'
import FineTune from './pages/FineTune'
import Studio from './pages/Studio'
import NotFound from './pages/NotFound'
import Usage from './pages/Usage'
import Users from './pages/Users'
@@ -44,6 +46,7 @@ function BrowseRedirect() {
  return <Navigate to={`/app/${splat || ''}`} replace />
}

function Admin({ children }) {
  return <RequireAdmin>{children}</RequireAdmin>
}
@@ -65,6 +68,7 @@ const appChildren = [
  { path: 'tts/:model', element: <TTS /> },
  { path: 'sound', element: <Sound /> },
  { path: 'sound/:model', element: <Sound /> },
  { path: 'studio', element: <Studio /> },
  { path: 'talk', element: <Talk /> },
  { path: 'usage', element: <Usage /> },
  { path: 'account', element: <Account /> },
@@ -90,6 +94,7 @@ const appChildren = [
  { path: 'agent-jobs/tasks/:id', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
  { path: 'agent-jobs/tasks/:id/edit', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
  { path: 'agent-jobs/jobs/:id', element: <Feature feature="mcp_jobs"><AgentJobDetails /></Feature> },
  { path: 'fine-tune', element: <Feature feature="fine_tuning"><FineTune /></Feature> },
  { path: 'model-editor/:name', element: <Admin><ModelEditor /></Admin> },
  { path: 'pipeline-editor', element: <Admin><PipelineEditor /></Admin> },
  { path: 'pipeline-editor/:name', element: <Admin><PipelineEditor /></Admin> },
19	core/http/react-ui/src/utils/api.js	vendored
@@ -380,6 +380,25 @@ export const apiKeysApi = {
  revoke: (id) => fetchJSON(`/api/auth/api-keys/${encodeURIComponent(id)}`, { method: 'DELETE' }),
}

// Fine-tuning API
export const fineTuneApi = {
  listBackends: () => fetchJSON('/api/fine-tuning/backends'),
  startJob: (data) => postJSON('/api/fine-tuning/jobs', data),
  listJobs: () => fetchJSON('/api/fine-tuning/jobs'),
  getJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`),
  stopJob: (id, saveCheckpoint) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/stop?save_checkpoint=${saveCheckpoint ? 'true' : 'false'}`, { method: 'POST' }),
  deleteJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`, { method: 'DELETE' }),
  listCheckpoints: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/checkpoints`),
  exportModel: (id, data) => postJSON(`/api/fine-tuning/jobs/${enc(id)}/export`, data),
  uploadDataset: (file) => {
    const formData = new FormData()
    formData.append('file', file)
    return fetch(apiUrl('/api/fine-tuning/datasets'), { method: 'POST', body: formData }).then(handleResponse)
  },
  progressUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/progress`),
  downloadUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/download`),
}

// File to base64 helper
export function fileToBase64(file) {
  return new Promise((resolve, reject) => {
@@ -777,9 +777,10 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
	}

	return c.JSON(http.StatusOK, map[string]interface{}{
		"agent_features": auth.AgentFeatureMetas(),
		"api_features":   auth.APIFeatureMetas(),
		"models":         modelNames,
		"agent_features":   auth.AgentFeatureMetas(),
		"general_features": auth.GeneralFeatureMetas(),
		"api_features":     auth.APIFeatureMetas(),
		"models":           modelNames,
	})
}, adminMw)
42	core/http/routes/finetuning.go	Normal file
@@ -0,0 +1,42 @@
package routes

import (
	"net/http"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/endpoints/localai"
	"github.com/mudler/LocalAI/core/services"
)

// RegisterFineTuningRoutes registers fine-tuning API routes.
func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) {
	if ftService == nil {
		return
	}

	// Service readiness middleware
	readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if ftService == nil {
				return c.JSON(http.StatusServiceUnavailable, map[string]string{
					"error": "fine-tuning service is not available",
				})
			}
			return next(c)
		}
	}

	ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw)
	ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig))
	ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService))
	ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService))
	ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService))
	ft.POST("/jobs/:id/stop", localai.StopFineTuneJobEndpoint(ftService))
	ft.DELETE("/jobs/:id", localai.DeleteFineTuneJobEndpoint(ftService))
	ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService))
	ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService))
	ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService))
	ft.GET("/jobs/:id/download", localai.DownloadExportedModelEndpoint(ftService))
	ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService))
}
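The route group above is the REST surface the React UI calls. As an illustrative sketch of what a `POST /api/fine-tuning/jobs` body might look like (the model name and dataset path below are placeholder values, not taken from this commit; field names follow the `json` tags of `schema.FineTuneJobRequest`):

```json
{
  "model": "my-base-model",
  "backend": "trl",
  "training_type": "lora",
  "training_method": "sft",
  "adapter_rank": 8,
  "adapter_alpha": 16,
  "learning_rate": 0.0002,
  "num_epochs": 3,
  "batch_size": 2,
  "dataset_source": "datasets/example.jsonl",
  "dataset_split": "train"
}
```

Optional fields marked `omitempty` in the schema can simply be left out; backend- and method-specific knobs go under `extra_options`.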
@@ -134,8 +134,9 @@ func RegisterLocalAIRoutes(router *echo.Echo,

	router.GET("/api/features", func(c echo.Context) error {
		return c.JSON(200, map[string]bool{
			"agents": appConfig.AgentPool.Enabled,
			"mcp":    !appConfig.DisableMCP,
			"agents":      appConfig.AgentPool.Enabled,
			"mcp":         !appConfig.DisableMCP,
			"fine_tuning": appConfig.FineTuning.Enabled,
		})
	})
111	core/schema/finetune.go	Normal file
@@ -0,0 +1,111 @@
|
||||
package schema

// RewardFunctionSpec defines a reward function for GRPO training.
type RewardFunctionSpec struct {
	Type   string            `json:"type"` // "builtin" or "inline"
	Name   string            `json:"name"`
	Code   string            `json:"code,omitempty"` // inline only
	Params map[string]string `json:"params,omitempty"`
}

// FineTuneJobRequest is the REST API request to start a fine-tuning job.
type FineTuneJobRequest struct {
	Model          string `json:"model"`
	Backend        string `json:"backend"` // "trl"
	TrainingType   string `json:"training_type,omitempty"`   // lora, loha, lokr, full
	TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo

	// Adapter config
	AdapterRank    int32    `json:"adapter_rank,omitempty"`
	AdapterAlpha   int32    `json:"adapter_alpha,omitempty"`
	AdapterDropout float32  `json:"adapter_dropout,omitempty"`
	TargetModules  []string `json:"target_modules,omitempty"`

	// Training hyperparameters
	LearningRate              float32 `json:"learning_rate,omitempty"`
	NumEpochs                 int32   `json:"num_epochs,omitempty"`
	BatchSize                 int32   `json:"batch_size,omitempty"`
	GradientAccumulationSteps int32   `json:"gradient_accumulation_steps,omitempty"`
	WarmupSteps               int32   `json:"warmup_steps,omitempty"`
	MaxSteps                  int32   `json:"max_steps,omitempty"`
	SaveSteps                 int32   `json:"save_steps,omitempty"`
	WeightDecay               float32 `json:"weight_decay,omitempty"`
	GradientCheckpointing     bool    `json:"gradient_checkpointing,omitempty"`
	Optimizer                 string  `json:"optimizer,omitempty"`
	Seed                      int32   `json:"seed,omitempty"`
	MixedPrecision            string  `json:"mixed_precision,omitempty"`

	// Dataset
	DatasetSource string `json:"dataset_source"`
	DatasetSplit  string `json:"dataset_split,omitempty"`

	// Resume from a checkpoint
	ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"`

	// GRPO reward functions
	RewardFunctions []RewardFunctionSpec `json:"reward_functions,omitempty"`

	// Backend-specific and method-specific options
	ExtraOptions map[string]string `json:"extra_options,omitempty"`
}

// FineTuneJob represents a fine-tuning job with its current state.
type FineTuneJob struct {
	ID             string            `json:"id"`
	UserID         string            `json:"user_id,omitempty"`
	Model          string            `json:"model"`
	Backend        string            `json:"backend"`
	ModelID        string            `json:"model_id,omitempty"` // backend model loader ID
	TrainingType   string            `json:"training_type"`
	TrainingMethod string            `json:"training_method"`
	Status         string            `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped
	Message        string            `json:"message,omitempty"`
	OutputDir      string            `json:"output_dir"`
	ExtraOptions   map[string]string `json:"extra_options,omitempty"`
	CreatedAt      string            `json:"created_at"`

	// Export state (tracked separately from training status)
	ExportStatus    string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed"
	ExportMessage   string `json:"export_message,omitempty"`
	ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export

	// Full config for resume/reuse
	Config *FineTuneJobRequest `json:"config,omitempty"`
}

// FineTuneJobResponse is the REST API response when creating a job.
type FineTuneJobResponse struct {
	ID      string `json:"id"`
	Status  string `json:"status"`
	Message string `json:"message"`
}

// FineTuneProgressEvent is an SSE event for training progress.
type FineTuneProgressEvent struct {
	JobID           string             `json:"job_id"`
	CurrentStep     int32              `json:"current_step"`
	TotalSteps      int32              `json:"total_steps"`
	CurrentEpoch    float32            `json:"current_epoch"`
	TotalEpochs     float32            `json:"total_epochs"`
	Loss            float32            `json:"loss"`
	LearningRate    float32            `json:"learning_rate"`
	GradNorm        float32            `json:"grad_norm"`
	EvalLoss        float32            `json:"eval_loss"`
	EtaSeconds      float32            `json:"eta_seconds"`
	ProgressPercent float32            `json:"progress_percent"`
	Status          string             `json:"status"`
	Message         string             `json:"message,omitempty"`
	CheckpointPath  string             `json:"checkpoint_path,omitempty"`
	SamplePath      string             `json:"sample_path,omitempty"`
	ExtraMetrics    map[string]float32 `json:"extra_metrics,omitempty"`
}

// ExportRequest is the REST API request to export a model.
type ExportRequest struct {
	Name               string            `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
	CheckpointPath     string            `json:"checkpoint_path"`
	ExportFormat       string            `json:"export_format"`       // lora, merged_16bit, merged_4bit, gguf
	QuantizationMethod string            `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16
	Model              string            `json:"model,omitempty"` // base model name for merge
	ExtraOptions       map[string]string `json:"extra_options,omitempty"`
}
@@ -1042,7 +1042,7 @@ func (s *AgentPoolService) CreateCollection(name string) error {
 	return s.collectionsBackend.CreateCollection(name)
 }
 
-func (s *AgentPoolService) UploadToCollection(collection, filename string, fileBody io.Reader) error {
+func (s *AgentPoolService) UploadToCollection(collection, filename string, fileBody io.Reader) (string, error) {
 	return s.collectionsBackend.Upload(collection, filename, fileBody)
 }
 
@@ -1554,10 +1554,10 @@ func (s *AgentPoolService) CreateCollectionForUser(userID, name string) error {
 }
 
 // UploadToCollectionForUser uploads to a collection for a specific user.
-func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) error {
+func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) (string, error) {
 	backend, err := s.CollectionsBackendForUser(userID)
 	if err != nil {
-		return err
+		return "", err
 	}
 	return backend.Upload(collection, filename, fileBody)
 }
700	core/services/finetune.go	Normal file
@@ -0,0 +1,700 @@
package services

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery/importers"
	"github.com/mudler/LocalAI/core/schema"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
	"gopkg.in/yaml.v3"
)

// FineTuneService manages fine-tuning jobs and their lifecycle.
type FineTuneService struct {
	appConfig    *config.ApplicationConfig
	modelLoader  *model.ModelLoader
	configLoader *config.ModelConfigLoader

	mu   sync.Mutex
	jobs map[string]*schema.FineTuneJob
}

// NewFineTuneService creates a new FineTuneService.
func NewFineTuneService(
	appConfig *config.ApplicationConfig,
	modelLoader *model.ModelLoader,
	configLoader *config.ModelConfigLoader,
) *FineTuneService {
	s := &FineTuneService{
		appConfig:    appConfig,
		modelLoader:  modelLoader,
		configLoader: configLoader,
		jobs:         make(map[string]*schema.FineTuneJob),
	}
	s.loadAllJobs()
	return s
}

// fineTuneBaseDir returns the base directory for fine-tune job data.
func (s *FineTuneService) fineTuneBaseDir() string {
	return filepath.Join(s.appConfig.DataPath, "fine-tune")
}

// jobDir returns the directory for a specific job.
func (s *FineTuneService) jobDir(jobID string) string {
	return filepath.Join(s.fineTuneBaseDir(), jobID)
}

// saveJobState persists a job's state to disk as state.json.
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
	dir := s.jobDir(job.ID)
	if err := os.MkdirAll(dir, 0750); err != nil {
		xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
		return
	}

	data, err := json.MarshalIndent(job, "", "  ")
	if err != nil {
		xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
		return
	}

	statePath := filepath.Join(dir, "state.json")
	if err := os.WriteFile(statePath, data, 0640); err != nil {
		xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
	}
}

// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
func (s *FineTuneService) loadAllJobs() {
	baseDir := s.fineTuneBaseDir()
	entries, err := os.ReadDir(baseDir)
	if err != nil {
		// Directory doesn't exist yet — that's fine
		return
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		statePath := filepath.Join(baseDir, entry.Name(), "state.json")
		data, err := os.ReadFile(statePath)
		if err != nil {
			continue
		}

		var job schema.FineTuneJob
		if err := json.Unmarshal(data, &job); err != nil {
			xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
			continue
		}

		// Jobs that were running when we shut down are now stale
		if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" {
			job.Status = "stopped"
			job.Message = "Server restarted while job was running"
		}

		// Exports that were in progress are now stale
		if job.ExportStatus == "exporting" {
			job.ExportStatus = "failed"
			job.ExportMessage = "Server restarted while export was running"
		}

		s.jobs[job.ID] = &job
	}

	if len(s.jobs) > 0 {
		xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
	}
}

// StartJob starts a new fine-tuning job.
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	jobID := uuid.New().String()

	backendName := req.Backend
	if backendName == "" {
		backendName = "trl"
	}

	// Always use DataPath for output — not user-configurable
	outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)

	// Build gRPC request
	grpcReq := &pb.FineTuneRequest{
		Model:                     req.Model,
		TrainingType:              req.TrainingType,
		TrainingMethod:            req.TrainingMethod,
		AdapterRank:               req.AdapterRank,
		AdapterAlpha:              req.AdapterAlpha,
		AdapterDropout:            req.AdapterDropout,
		TargetModules:             req.TargetModules,
		LearningRate:              req.LearningRate,
		NumEpochs:                 req.NumEpochs,
		BatchSize:                 req.BatchSize,
		GradientAccumulationSteps: req.GradientAccumulationSteps,
		WarmupSteps:               req.WarmupSteps,
		MaxSteps:                  req.MaxSteps,
		SaveSteps:                 req.SaveSteps,
		WeightDecay:               req.WeightDecay,
		GradientCheckpointing:     req.GradientCheckpointing,
		Optimizer:                 req.Optimizer,
		Seed:                      req.Seed,
		MixedPrecision:            req.MixedPrecision,
		DatasetSource:             req.DatasetSource,
		DatasetSplit:              req.DatasetSplit,
		OutputDir:                 outputDir,
		JobId:                     jobID,
		ResumeFromCheckpoint:      req.ResumeFromCheckpoint,
		ExtraOptions:              req.ExtraOptions,
	}

	// Serialize reward functions into extra_options for the backend
	if len(req.RewardFunctions) > 0 {
		rfJSON, err := json.Marshal(req.RewardFunctions)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize reward functions: %w", err)
		}
		if grpcReq.ExtraOptions == nil {
			grpcReq.ExtraOptions = make(map[string]string)
		}
		grpcReq.ExtraOptions["reward_funcs"] = string(rfJSON)
	}

	// Load the fine-tuning backend (per-job model ID so multiple jobs can run concurrently)
	modelID := backendName + "-finetune-" + jobID
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(backendName),
		model.WithModel(backendName),
		model.WithModelID(modelID),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
	}

	// Start fine-tuning via gRPC
	result, err := backendModel.StartFineTune(ctx, grpcReq)
	if err != nil {
		return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
	}
	if !result.Success {
		return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
	}

	// Track the job
	job := &schema.FineTuneJob{
		ID:             jobID,
		UserID:         userID,
		Model:          req.Model,
		Backend:        backendName,
		ModelID:        modelID,
		TrainingType:   req.TrainingType,
		TrainingMethod: req.TrainingMethod,
		Status:         "queued",
		OutputDir:      outputDir,
		ExtraOptions:   req.ExtraOptions,
		CreatedAt:      time.Now().UTC().Format(time.RFC3339),
		Config:         &req,
	}
	s.jobs[jobID] = job
	s.saveJobState(job)

	return &schema.FineTuneJobResponse{
		ID:      jobID,
		Status:  "queued",
		Message: result.Message,
	}, nil
}

// GetJob returns a fine-tuning job by ID.
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	job, ok := s.jobs[jobID]
	if !ok {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	return job, nil
}

// ListJobs returns all jobs for a user, sorted by creation time (newest first).
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
	s.mu.Lock()
	defer s.mu.Unlock()

	var result []*schema.FineTuneJob
	for _, job := range s.jobs {
		if userID == "" || job.UserID == userID {
			result = append(result, job)
		}
	}

	sort.Slice(result, func(i, j int) bool {
		return result[i].CreatedAt > result[j].CreatedAt
	})

	return result
}

// StopJob stops a running fine-tuning job.
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	// Kill the backend process directly
	stopModelID := job.ModelID
	if stopModelID == "" {
		stopModelID = job.Backend + "-finetune"
	}
	s.modelLoader.ShutdownModel(stopModelID)

	s.mu.Lock()
	job.Status = "stopped"
	job.Message = "Training stopped by user"
	s.saveJobState(job)
	s.mu.Unlock()

	return nil
}

// DeleteJob removes a fine-tuning job and its associated data from disk.
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}

	// Reject deletion of actively running jobs
	activeStatuses := map[string]bool{
		"queued": true, "loading_model": true, "loading_dataset": true,
		"training": true, "saving": true,
	}
	if activeStatuses[job.Status] {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status)
	}
	if job.ExportStatus == "exporting" {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: export in progress", jobID)
	}

	exportModelName := job.ExportModelName
	delete(s.jobs, jobID)
	s.mu.Unlock()

	// Remove job directory (state.json, checkpoints, output)
	jobDir := s.jobDir(jobID)
	if err := os.RemoveAll(jobDir); err != nil {
		xlog.Warn("Failed to remove job directory", "job_id", jobID, "path", jobDir, "error", err)
	}

	// If an exported model exists, clean it up too
	if exportModelName != "" {
		modelsPath := s.appConfig.SystemState.Model.ModelsPath
		modelDir := filepath.Join(modelsPath, exportModelName)
		configPath := filepath.Join(modelsPath, exportModelName+".yaml")

		if err := os.RemoveAll(modelDir); err != nil {
			xlog.Warn("Failed to remove exported model directory", "path", modelDir, "error", err)
		}
		if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
			xlog.Warn("Failed to remove exported model config", "path", configPath, "error", err)
		}

		// Reload model configs
		if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
			xlog.Warn("Failed to reload configs after delete", "error", err)
		}
	}

	xlog.Info("Deleted fine-tune job", "job_id", jobID)
	return nil
}

// StreamProgress opens a gRPC progress stream and calls the callback for each update.
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	streamModelID := job.ModelID
	if streamModelID == "" {
		streamModelID = job.Backend + "-finetune"
	}
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(streamModelID),
	)
	if err != nil {
		return fmt.Errorf("failed to load backend: %w", err)
	}

	return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{
		JobId: jobID,
	}, func(update *pb.FineTuneProgressUpdate) {
		// Update job status and persist
		s.mu.Lock()
		if j, ok := s.jobs[jobID]; ok {
			// Don't let progress updates overwrite terminal states
			isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
			if !isTerminal {
				j.Status = update.Status
			}
			if update.Message != "" {
				j.Message = update.Message
			}
			s.saveJobState(j)
		}
		s.mu.Unlock()

		// Convert extra metrics
		extraMetrics := make(map[string]float32)
		for k, v := range update.ExtraMetrics {
			extraMetrics[k] = v
		}

		event := &schema.FineTuneProgressEvent{
			JobID:           update.JobId,
			CurrentStep:     update.CurrentStep,
			TotalSteps:      update.TotalSteps,
			CurrentEpoch:    update.CurrentEpoch,
			TotalEpochs:     update.TotalEpochs,
			Loss:            update.Loss,
			LearningRate:    update.LearningRate,
			GradNorm:        update.GradNorm,
			EvalLoss:        update.EvalLoss,
			EtaSeconds:      update.EtaSeconds,
			ProgressPercent: update.ProgressPercent,
			Status:          update.Status,
			Message:         update.Message,
			CheckpointPath:  update.CheckpointPath,
			SamplePath:      update.SamplePath,
			ExtraMetrics:    extraMetrics,
		}
		callback(event)
	})
}

// ListCheckpoints lists checkpoints for a job.
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	ckptModelID := job.ModelID
	if ckptModelID == "" {
		ckptModelID = job.Backend + "-finetune"
	}
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(ckptModelID),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load backend: %w", err)
	}

	resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
		OutputDir: job.OutputDir,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to list checkpoints: %w", err)
	}

	return resp.Checkpoints, nil
}

// sanitizeModelName replaces non-alphanumeric characters with hyphens and lowercases.
func sanitizeModelName(s string) string {
	re := regexp.MustCompile(`[^a-zA-Z0-9\-]`)
	s = re.ReplaceAllString(s, "-")
	s = regexp.MustCompile(`-+`).ReplaceAllString(s, "-")
	s = strings.Trim(s, "-")
	return strings.ToLower(s)
}

// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return "", fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return "", fmt.Errorf("job not found: %s", jobID)
	}
	if job.ExportStatus == "exporting" {
		s.mu.Unlock()
		return "", fmt.Errorf("export already in progress for job %s", jobID)
	}
	s.mu.Unlock()

	// Compute model name
	modelName := req.Name
	if modelName == "" {
		base := sanitizeModelName(job.Model)
		if base == "" {
			base = "model"
		}
		shortID := jobID
		if len(shortID) > 8 {
			shortID = shortID[:8]
		}
		modelName = base + "-ft-" + shortID
	}

	// Compute output path in models directory
	modelsPath := s.appConfig.SystemState.Model.ModelsPath
	outputPath := filepath.Join(modelsPath, modelName)

	// Check for name collision (synchronous — fast validation)
	configPath := filepath.Join(modelsPath, modelName+".yaml")
	if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
		return "", fmt.Errorf("invalid model name: %w", err)
	}
	if _, err := os.Stat(configPath); err == nil {
		return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
	}

	// Create output directory
	if err := os.MkdirAll(outputPath, 0750); err != nil {
		return "", fmt.Errorf("failed to create output directory: %w", err)
	}

	// Set export status to "exporting" and persist
	s.mu.Lock()
	job.ExportStatus = "exporting"
	job.ExportMessage = ""
	job.ExportModelName = ""
	s.saveJobState(job)
	s.mu.Unlock()

	// Launch the export in a background goroutine
	go func() {
		s.setExportMessage(job, "Loading export backend...")

		exportModelID := job.ModelID
		if exportModelID == "" {
			exportModelID = job.Backend + "-finetune"
		}
		backendModel, err := s.modelLoader.Load(
			model.WithBackendString(job.Backend),
			model.WithModel(job.Backend),
			model.WithModelID(exportModelID),
		)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
			return
		}

		// Merge job's extra_options (contains hf_token from training) with request's
		mergedOpts := make(map[string]string)
		for k, v := range job.ExtraOptions {
			mergedOpts[k] = v
		}
		for k, v := range req.ExtraOptions {
			mergedOpts[k] = v // request overrides job
		}

		grpcReq := &pb.ExportModelRequest{
			CheckpointPath:     req.CheckpointPath,
			OutputPath:         outputPath,
			ExportFormat:       req.ExportFormat,
			QuantizationMethod: req.QuantizationMethod,
			Model:              req.Model,
			ExtraOptions:       mergedOpts,
		}

		s.setExportMessage(job, "Running model export (merging and converting — this may take a while)...")

		result, err := backendModel.ExportModel(context.Background(), grpcReq)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
			return
		}
		if !result.Success {
			s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
			return
		}

		s.setExportMessage(job, "Export complete, generating model configuration...")

		// Auto-import: detect format and generate config
		cfg, err := importers.ImportLocalPath(outputPath, modelName)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
			return
		}

		cfg.Name = modelName

		// If base model not detected from files, use the job's model field
		if cfg.Model == "" && job.Model != "" {
			cfg.Model = job.Model
		}

		// Write YAML config
		yamlData, err := yaml.Marshal(cfg)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
			return
		}
		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
			return
		}

		s.setExportMessage(job, "Registering model with LocalAI...")

		// Reload configs so the model is immediately available
		if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
			xlog.Warn("Failed to reload configs after export", "error", err)
		}
		if err := s.configLoader.Preload(modelsPath); err != nil {
			xlog.Warn("Failed to preload after export", "error", err)
		}

		xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)

		s.mu.Lock()
		job.ExportStatus = "completed"
		job.ExportModelName = modelName
		job.ExportMessage = ""
		s.saveJobState(job)
		s.mu.Unlock()
	}()

	return modelName, nil
}

// setExportMessage updates the export message and persists the job state.
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
	s.mu.Lock()
	job.ExportMessage = msg
	s.saveJobState(job)
	s.mu.Unlock()
}

// GetExportedModelPath returns the path to the exported model directory and its name.
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return "", "", fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return "", "", fmt.Errorf("job not found: %s", jobID)
	}
	if job.ExportStatus != "completed" {
		s.mu.Unlock()
		return "", "", fmt.Errorf("export not completed for job %s (status: %s)", jobID, job.ExportStatus)
	}
	exportModelName := job.ExportModelName
	s.mu.Unlock()

	if exportModelName == "" {
		return "", "", fmt.Errorf("no exported model name for job %s", jobID)
	}

	modelsPath := s.appConfig.SystemState.Model.ModelsPath
	modelDir := filepath.Join(modelsPath, exportModelName)

	if _, err := os.Stat(modelDir); os.IsNotExist(err) {
		return "", "", fmt.Errorf("exported model directory not found: %s", modelDir)
	}

	return modelDir, exportModelName, nil
}

// setExportFailed sets the export status to failed with a message.
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
	xlog.Error("Export failed", "job_id", job.ID, "error", message)
	s.mu.Lock()
	job.ExportStatus = "failed"
	job.ExportMessage = message
	s.saveJobState(job)
	s.mu.Unlock()
}

// UploadDataset handles dataset file upload and returns the local path.
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
	uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
	if err := os.MkdirAll(uploadDir, 0750); err != nil {
		return "", fmt.Errorf("failed to create dataset directory: %w", err)
	}

	filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename)
	if err := os.WriteFile(filePath, data, 0640); err != nil {
		return "", fmt.Errorf("failed to write dataset: %w", err)
	}

	return filePath, nil
}

// MarshalProgressEvent converts a progress event to JSON for SSE.
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
	data, err := json.Marshal(event)
	if err != nil {
		return "", err
	}
	return string(data), nil
}
226	docs/content/features/fine-tuning.md	Normal file
@@ -0,0 +1,226 @@
+++
disableToc = false
title = "Fine-Tuning"
weight = 18
url = '/features/fine-tuning/'
+++

LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types.

## Supported Backends

| Backend | Domain | GPU Required | Training Methods | Adapter Types |
|---------|--------|--------------|------------------|---------------|
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |

## Enabling Fine-Tuning

Fine-tuning is disabled by default. Enable it with:

```bash
LOCALAI_ENABLE_FINETUNING=true local-ai
```

When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.

## Quick Start

### 1. Start a fine-tuning job

```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "backend": "trl",
    "training_method": "sft",
    "training_type": "lora",
    "dataset_source": "yahma/alpaca-cleaned",
    "num_epochs": 1,
    "batch_size": 2,
    "learning_rate": 0.0002,
    "adapter_rank": 16,
    "adapter_alpha": 16,
    "extra_options": {
      "max_seq_length": "512"
    }
  }'
```

### 2. Monitor progress (SSE stream)

```bash
curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress
```
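
Each message on this stream is an SSE `data:` line whose payload is a JSON-encoded `FineTuneProgressEvent` (see `core/schema/finetune.go` in this commit). A rough client-side sketch of decoding one line — the `parseProgressLine` helper and the reduced field set are illustrative, not part of LocalAI:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// ProgressEvent mirrors a subset of schema.FineTuneProgressEvent.
type ProgressEvent struct {
	JobID           string  `json:"job_id"`
	CurrentStep     int32   `json:"current_step"`
	TotalSteps      int32   `json:"total_steps"`
	Loss            float32 `json:"loss"`
	ProgressPercent float32 `json:"progress_percent"`
	Status          string  `json:"status"`
}

// parseProgressLine decodes one SSE "data:" line into a ProgressEvent.
func parseProgressLine(line string) (ProgressEvent, error) {
	var ev ProgressEvent
	payload, ok := strings.CutPrefix(line, "data:")
	if !ok {
		return ev, fmt.Errorf("not an SSE data line")
	}
	err := json.Unmarshal([]byte(strings.TrimSpace(payload)), &ev)
	return ev, err
}

func main() {
	line := `data: {"job_id":"abc","current_step":10,"total_steps":100,"loss":1.23,"status":"training"}`
	ev, err := parseProgressLine(line)
	if err != nil {
		panic(err)
	}
	fmt.Printf("step %d/%d loss=%.2f status=%s\n", ev.CurrentStep, ev.TotalSteps, ev.Loss, ev.Status)
	// prints: step 10/100 loss=1.23 status=training
}
```

Unknown JSON fields are simply ignored by `encoding/json`, so a client only needs to declare the fields it consumes.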

### 3. List checkpoints

```bash
curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints
```

### 4. Export model

```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \
  -H "Content-Type: application/json" \
  -d '{
    "name": "my-finetuned-model",
    "checkpoint_path": "<checkpoint path from step 3>",
    "export_format": "gguf",
    "quantization_method": "q4_k_m"
  }'
```
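
If `name` is omitted, the service derives one as `<sanitized base model>-ft-<first 8 chars of the job ID>`. A standalone sketch of that naming rule — the `deriveExportName` helper is illustrative; it mirrors `sanitizeModelName` in `core/services/finetune.go`:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// deriveExportName predicts the auto-generated export model name:
// sanitize the base model, then append "-ft-" plus a short job ID.
func deriveExportName(baseModel, jobID string) string {
	base := regexp.MustCompile(`[^a-zA-Z0-9\-]`).ReplaceAllString(baseModel, "-")
	base = regexp.MustCompile(`-+`).ReplaceAllString(base, "-")
	base = strings.ToLower(strings.Trim(base, "-"))
	if base == "" {
		base = "model"
	}
	short := jobID
	if len(short) > 8 {
		short = short[:8]
	}
	return base + "-ft-" + short
}

func main() {
	fmt.Println(deriveExportName("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "0a1b2c3d-4e5f"))
	// prints: tinyllama-tinyllama-1-1b-chat-v1-0-ft-0a1b2c3d
}
```

This is useful for predicting the registered model name before polling the job's `export_status`.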

## API Reference

### Endpoints

| Method | Path | Description |
|--------|------|-------------|
| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job |
| `GET` | `/api/fine-tuning/jobs` | List all jobs |
| `GET` | `/api/fine-tuning/jobs/:id` | Get job details |
| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job |
| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream |
| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints |
| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model |
| `POST` | `/api/fine-tuning/datasets` | Upload dataset file |

### Job Request Fields

| Field | Type | Description |
|-------|------|-------------|
| `model` | string | HuggingFace model ID or local path (required) |
| `backend` | string | Backend name (default: `trl`) |
| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` |
| `training_type` | string | `lora` or `full` |
| `dataset_source` | string | HuggingFace dataset ID or local file path (required) |
| `adapter_rank` | int | LoRA rank (default: 16) |
| `adapter_alpha` | int | LoRA alpha (default: 16) |
| `num_epochs` | int | Number of training epochs (default: 3) |
| `batch_size` | int | Per-device batch size (default: 2) |
| `learning_rate` | float | Learning rate (default: 2e-4) |
| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) |
| `warmup_steps` | int | Warmup steps (default: 5) |
| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` |
| `extra_options` | map | Backend-specific options (see below) |
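
The defaults above can be captured in a small client-side helper. This is a hypothetical convenience sketch, not part of LocalAI; the values for `training_method`, `training_type`, and `optimizer` have no documented defaults and are simply chosen for the sketch.

```python
def build_finetune_job(model, dataset_source, **overrides):
    """Build a fine-tuning job request body, pre-filled with the
    documented defaults; any keyword argument overrides a field."""
    job = {
        "model": model,                    # required
        "dataset_source": dataset_source,  # required
        "backend": "trl",
        "training_method": "sft",          # no documented default; chosen for the sketch
        "training_type": "lora",           # no documented default; chosen for the sketch
        "adapter_rank": 16,
        "adapter_alpha": 16,
        "num_epochs": 3,
        "batch_size": 2,
        "learning_rate": 2e-4,
        "gradient_accumulation_steps": 4,
        "warmup_steps": 5,
        "optimizer": "adamw_torch",        # no documented default; chosen for the sketch
        "extra_options": {},
    }
    job.update(overrides)
    return job

job = build_finetune_job("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                         "yahma/alpaca-cleaned",
                         num_epochs=1)
```

The resulting dict can be POSTed as JSON to `/api/fine-tuning/jobs`.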

### Backend-Specific Options (`extra_options`)

#### TRL

| Key | Description | Default |
|-----|-------------|---------|
| `max_seq_length` | Maximum sequence length | `512` |
| `packing` | Enable sequence packing | `false` |
| `trust_remote_code` | Trust remote code in model | `false` |
| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` |

#### DPO-specific (training_method=dpo)

| Key | Description | Default |
|-----|-------------|---------|
| `beta` | KL penalty coefficient | `0.1` |
| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` |
| `max_length` | Maximum sequence length | `512` |

#### GRPO-specific (training_method=grpo)

| Key | Description | Default |
|-----|-------------|---------|
| `num_generations` | Number of generations per prompt | `4` |
| `max_completion_length` | Max completion token length | `256` |

### GRPO Reward Functions

GRPO training requires reward functions to evaluate model completions. Specify them via the `reward_functions` field (a typed array) or via `extra_options["reward_funcs"]` (a JSON string).

#### Built-in Reward Functions

| Name | Description | Parameters |
|------|-------------|-----------|
| `format_reward` | Checks `<think>...</think>` then answer format (1.0/0.0) | — |
| `reasoning_accuracy_reward` | Extracts `<answer>` content, compares to dataset's `answer` column | — |
| `length_reward` | Score based on proximity to target length [0, 1] | `target_length` (default: 200) |
| `xml_tag_reward` | Scores properly opened/closed `<think>` and `<answer>` tags | — |
| `no_repetition_reward` | Penalizes n-gram repetition [0, 1] | — |
| `code_execution_reward` | Checks Python code block syntax validity (1.0/0.0) | — |

#### Inline Custom Reward Functions

You can provide custom reward function code as a Python function body. The function receives `completions` (list of strings) and `**kwargs`, and must return `list[float]`.

**Security restrictions for inline code:**

- Allowed builtins: `len`, `int`, `float`, `str`, `list`, `dict`, `range`, `enumerate`, `zip`, `map`, `filter`, `sorted`, `min`, `max`, `sum`, `abs`, `round`, `any`, `all`, `isinstance`, `print`, `True`, `False`, `None`
- Available modules: `re`, `math`, `json`, `string`
- Blocked: `open`, `__import__`, `exec`, `eval`, `compile`, `os`, `subprocess`, `getattr`, `setattr`, `delattr`, `globals`, `locals`
- Functions are compiled and validated at job start (fail-fast on syntax errors)
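
A sketch of a reward function that stays within these restrictions, using only the `re` module and allowed builtins. For inline use you would submit just the function body; here it is wrapped as a plain function so it can be tested locally. The function itself is illustrative, not one of the built-ins.

```python
import re

def xmlish_reward(completions, **kwargs):
    """Return 1.0 for completions containing a well-formed
    <think>...</think> block followed by <answer>...</answer>, else 0.0.
    The body only touches `re` and allowed builtins, so it is also
    valid as inline reward code."""
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [1.0 if pattern.search(c) else 0.0 for c in completions]

scores = xmlish_reward([
    "<think>2+2=4</think> <answer>4</answer>",
    "just an answer with no tags",
])
# scores == [1.0, 0.0]
```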

#### Example API Request

```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen2.5-1.5B-Instruct",
    "backend": "trl",
    "training_method": "grpo",
    "training_type": "lora",
    "dataset_source": "my-reasoning-dataset",
    "num_epochs": 1,
    "batch_size": 2,
    "learning_rate": 5e-6,
    "reward_functions": [
      {"type": "builtin", "name": "reasoning_accuracy_reward"},
      {"type": "builtin", "name": "format_reward"},
      {"type": "builtin", "name": "length_reward", "params": {"target_length": "200"}},
      {"type": "inline", "name": "think_presence", "code": "return [1.0 if \"<think>\" in c else 0.0 for c in completions]"}
    ],
    "extra_options": {
      "num_generations": "4",
      "max_completion_length": "256"
    }
  }'
```

### Export Formats

| Format | Description | Notes |
|--------|-------------|-------|
| `lora` | LoRA adapter files | Smallest, requires base model |
| `merged_16bit` | Full model in 16-bit | Large but standalone |
| `merged_4bit` | Full model in 4-bit | Smaller, standalone |
| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` |

### GGUF Quantization Methods

`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0`

## Web UI

When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides:

1. **Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters
2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets
3. **Training Monitor** — Real-time loss chart, progress bar, metrics display
4. **Export** — Export trained models in various formats

## Dataset Formats

Datasets should follow standard HuggingFace formats:

- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT
- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields)
- **GRPO**: Prompts with reward signals

Supported file formats: `.json`, `.jsonl`, `.csv`
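
For instance, an SFT dataset in Alpaca format can be written as JSONL, one record per line. The records below are illustrative; only the field names come from the format described above.

```python
import json

# Two illustrative Alpaca-format records.
records = [
    {"instruction": "Translate to French.", "input": "Good morning",
     "output": "Bonjour"},
    {"instruction": "Summarize in one sentence.",
     "input": "LocalAI is a local, OpenAI-compatible API server.",
     "output": "LocalAI runs OpenAI-style APIs locally."},
]

# One JSON object per line, the shape expected of a .jsonl dataset file.
with open("train.jsonl", "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
```

The resulting file can be uploaded via `POST /api/fine-tuning/datasets` or referenced by local path in `dataset_source`.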

## Architecture

Fine-tuning uses the same gRPC backend architecture as inference:

1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel`
2. **Python backends**: Each backend implements the gRPC interface with its specific training framework
3. **Go service**: Manages job lifecycle, routes API requests to backends
4. **REST API**: HTTP endpoints with SSE progress streaming
5. **React UI**: Configuration form, real-time training monitor, export panel
12 go.mod

@@ -67,10 +67,18 @@ require (
)

require (
	github.com/chasefleming/elem-go v0.30.0 // indirect
	github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 // indirect
	github.com/go-jose/go-jose/v4 v4.1.3 // indirect
	github.com/gofiber/template v1.8.3 // indirect
	github.com/gofiber/template/html/v2 v2.1.3 // indirect
	github.com/gofiber/utils v1.1.0 // indirect
	github.com/inconshreveable/mousetrap v1.1.0 // indirect
	github.com/jinzhu/inflection v1.0.0 // indirect
	github.com/jinzhu/now v1.1.5 // indirect
	github.com/mattn/go-sqlite3 v1.14.22 // indirect
	github.com/spf13/cobra v1.10.2 // indirect
	github.com/spf13/pflag v1.0.9 // indirect
	github.com/stretchr/testify v1.11.1 // indirect
	github.com/tmc/langchaingo v0.1.14 // indirect
)

@@ -136,8 +144,8 @@ require (
	github.com/kevinburke/ssh_config v1.2.0 // indirect
	github.com/labstack/gommon v0.4.2 // indirect
	github.com/mschoch/smat v0.2.0 // indirect
	github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a
	github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 // indirect
	github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4
	github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b // indirect
	github.com/mudler/skillserver v0.0.5
	github.com/olekukonko/tablewriter v0.0.5 // indirect
	github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect

24 go.sum

@@ -148,6 +148,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Y
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chasefleming/elem-go v0.30.0 h1:BlhV1ekv1RbFiM8XZUQeln1Ikb4D+bu2eDO4agREvok=
github.com/chasefleming/elem-go v0.30.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f6vg71RUilJAA4=
github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM=
github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@@ -177,6 +179,7 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creachadair/mds v0.21.3 h1:RRgEAPIb52cU0q7UxGyN+13QlCVTZIL4slRr0cYYQfA=
github.com/creachadair/mds v0.21.3/go.mod h1:1ltMWZd9yXhaHEoZwBialMaviWVUpRPvMwVP7saFAzM=
github.com/creachadair/otp v0.5.0 h1:q3Th7CXm2zlmCdBjw5tEPFOj4oWJMnVL5HXlq0sNKS0=
@@ -185,6 +188,8 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/cyphar/filepath-securejoin v0.5.1 h1:eYgfMq5yryL4fbWfkLpFFy2ukSELzaJOTaUTuh+oF48=
github.com/cyphar/filepath-securejoin v0.5.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0=
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
@@ -341,6 +346,12 @@ github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/gofiber/template v1.8.3 h1:hzHdvMwMo/T2kouz2pPCA0zGiLCeMnoGsQZBTSYgZxc=
github.com/gofiber/template v1.8.3/go.mod h1:bs/2n0pSNPOkRa5VJ8zTIvedcI/lEYxzV3+YPXdBvq8=
github.com/gofiber/template/html/v2 v2.1.3 h1:n1LYBtmr9C0V/k/3qBblXyMxV5B0o/gpb6dFLp8ea+o=
github.com/gofiber/template/html/v2 v2.1.3/go.mod h1:U5Fxgc5KpyujU9OqKzy6Kn6Qup6Tm7zdsISR+VpnHRE=
github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM=
github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0=
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@@ -445,6 +456,8 @@ github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc=
github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/ipfs/boxo v0.30.0 h1:7afsoxPGGqfoH7Dum/wOTGUB9M5fb8HyKPMlLfBvIEQ=
github.com/ipfs/boxo v0.30.0/go.mod h1:BPqgGGyHB9rZZcPSzah2Dc9C+5Or3U1aQe7EH1H7370=
github.com/ipfs/go-block-format v0.2.0 h1:ZqrkxBA2ICbDRbK8KJs/u0O3dlp6gmAuuXUJNiW1Ycs=
@@ -666,6 +679,8 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM=
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a h1:combrnE/eLPnUhqrYmtFmqEfR6x9xS+HoTFdnMozvik=
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a/go.mod h1:AbBcAE9JqkexN4aG8rYQn5LzmzffWqcMvQ+Nlvin3WI=
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4 h1:zWrAdAI/gwAPwXQAJuFLF8vvJdsxpxjKiBiC0EzhLOo=
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4/go.mod h1:g+6CD5tP4a+rRW20CrMpE/JDazq5N4n4YDxIT7tT1mY=
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b h1:A74T2Lauvg61KodYqsjTYDY05kPLcW+efVZjd23dghU=
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
github.com/mudler/edgevpn v0.31.1 h1:7qegiDWd0kAg6ljhNHxqvp8hbo/6BbzSdbb7/2WZfiY=
@@ -676,6 +691,10 @@ github.com/mudler/go-processmanager v0.1.0 h1:fcSKgF9U/a1Z7KofAFeZnke5YseadCI5Gq
github.com/mudler/go-processmanager v0.1.0/go.mod h1:h6kmHUZeafr+k5hRYpGLMzJFH4hItHffgpRo2QIkP+o=
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 h1:KVTEukvLlQXKZx1C1ZLru+ahaiECLF+7v2caK8vauJ0=
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45 h1:+zTrbYk70wHrtvpsO2k7gMPvHYnWYCnXNxAtMex+7yg=
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b h1:XeAnOEOOSKMfS5XNGpRTltQgjKCinho0V4uAhrgxN7Q=
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b/go.mod h1:xuPtgL9zUyiQLmspYzO3kaboYrGbWmwi8BQPt1aCAcs=
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2 h1:+WHsL/j6EWOMUiMVIOJNKOwSKiQt/qDPc9fePCf87fA=
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
github.com/mudler/skillserver v0.0.5 h1:t6HPpeSX8kEP7B8F5GXoQUam5VEYNmJuG6oy2/vdTu8=
@@ -855,6 +874,7 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
@@ -928,6 +948,10 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=

@@ -63,4 +63,11 @@ type Backend interface {
	AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)

	ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)

	// Fine-tuning
	StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error)
	FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error
	StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
	ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error)
	ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error)
}

@@ -120,6 +120,26 @@ func (llm *Base) AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, err
	return nil, fmt.Errorf("unimplemented")
}

func (llm *Base) StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
	return nil, fmt.Errorf("unimplemented")
}

func (llm *Base) FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) StopFineTune(*pb.FineTuneStopRequest) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
	return nil, fmt.Errorf("unimplemented")
}

func (llm *Base) ExportModel(*pb.ExportModelRequest) error {
	return fmt.Errorf("unimplemented")
}

func memoryUsage() *pb.MemoryUsageData {
	mud := pb.MemoryUsageData{
		Breakdown: make(map[string]uint64),

@@ -632,6 +632,142 @@ func (c *Client) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opt
	return client.AudioDecode(ctx, in, opts...)
}

func (c *Client) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.StartFineTune(ctx, in, opts...)
}

func (c *Client) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	stream, err := client.FineTuneProgress(ctx, in, opts...)
	if err != nil {
		return err
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		update, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
		f(update)
	}

	return nil
}

func (c *Client) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.StopFineTune(ctx, in, opts...)
}

func (c *Client) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.ListCheckpoints(ctx, in, opts...)
}

func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)
	return client.ExportModel(ctx, in, opts...)
}

func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
	if !c.parallel {
		c.opMutex.Lock()

@@ -123,6 +123,68 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
	return e.s.GetMetrics(ctx, in)
}

func (e *embedBackend) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
	return e.s.StartFineTune(ctx, in)
}

func (e *embedBackend) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
	bs := &embedBackendFineTuneProgressStream{
		ctx: ctx,
		fn:  f,
	}
	return e.s.FineTuneProgress(in, bs)
}

func (e *embedBackend) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	return e.s.StopFineTune(ctx, in)
}

func (e *embedBackend) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	return e.s.ListCheckpoints(ctx, in)
}

func (e *embedBackend) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	return e.s.ExportModel(ctx, in)
}

var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)

type embedBackendFineTuneProgressStream struct {
	ctx context.Context
	fn  func(update *pb.FineTuneProgressUpdate)
}

func (e *embedBackendFineTuneProgressStream) Send(update *pb.FineTuneProgressUpdate) error {
	e.fn(update)
	return nil
}

func (e *embedBackendFineTuneProgressStream) SetHeader(md metadata.MD) error {
	return nil
}

func (e *embedBackendFineTuneProgressStream) SendHeader(md metadata.MD) error {
	return nil
}

func (e *embedBackendFineTuneProgressStream) SetTrailer(md metadata.MD) {
}

func (e *embedBackendFineTuneProgressStream) Context() context.Context {
	return e.ctx
}

func (e *embedBackendFineTuneProgressStream) SendMsg(m any) error {
	if x, ok := m.(*pb.FineTuneProgressUpdate); ok {
		return e.Send(x)
	}
	return nil
}

func (e *embedBackendFineTuneProgressStream) RecvMsg(m any) error {
	return nil
}

type embedBackendServerStream struct {
	ctx context.Context
	fn  func(reply *pb.Reply)

@@ -35,6 +35,13 @@ type AIModel interface {
	AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)

	ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)

	// Fine-tuning
	StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error)
	FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error
	StopFineTune(*pb.FineTuneStopRequest) error
	ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error)
	ExportModel(*pb.ExportModelRequest) error
}

func newReply(s string) *pb.Reply {

@@ -308,6 +308,75 @@ func (s *server) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*p
	return res, nil
}

func (s *server) StartFineTune(ctx context.Context, in *pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.StartFineTune(in)
	if err != nil {
		return &pb.FineTuneJobResult{Success: false, Message: fmt.Sprintf("Error starting fine-tune: %s", err.Error())}, err
	}
	return res, nil
}

func (s *server) FineTuneProgress(in *pb.FineTuneProgressRequest, stream pb.Backend_FineTuneProgressServer) error {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	updateChan := make(chan *pb.FineTuneProgressUpdate)

	done := make(chan bool)
	go func() {
		for update := range updateChan {
			stream.Send(update)
		}
		done <- true
	}()

	err := s.llm.FineTuneProgress(in, updateChan)
	<-done

	return err
}

func (s *server) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.StopFineTune(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error stopping fine-tune: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Fine-tune stopped", Success: true}, nil
}

func (s *server) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.ListCheckpoints(in)
	if err != nil {
		return nil, err
	}
	return res, nil
}

func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.ExportModel(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error exporting model: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Model exported", Success: true}, nil
}

func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
	if s.llm.Locking() {
		s.llm.Lock()