mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-22 15:49:12 -04:00
Compare commits
6 Commits
feat/recon
...
feat/fine-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8997ff6042 | ||
|
|
f1223b45b2 | ||
|
|
fa8b1a8673 | ||
|
|
3451dbdccd | ||
|
|
7b8afc9609 | ||
|
|
ae4b758a5a |
141
.agents/debugging-backends.md
Normal file
141
.agents/debugging-backends.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Debugging and Rebuilding Backends
|
||||||
|
|
||||||
|
When a backend fails at runtime (e.g. a gRPC method error, a Python import error, or a dependency conflict), use this guide to diagnose, fix, and rebuild.
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
- **Source directory**: `backend/python/<name>/` (or `backend/go/<name>/`, `backend/cpp/<name>/`)
|
||||||
|
- **Installed directory**: `backends/<name>/` — this is what LocalAI actually runs. It is populated by `make backends/<name>` which builds a Docker image, exports it, and installs it via `local-ai backends install`.
|
||||||
|
- **Virtual environment**: `backends/<name>/venv/` — the installed Python venv (for Python backends). The Python binary is at `backends/<name>/venv/bin/python`.
|
||||||
|
|
||||||
|
Editing files in `backend/python/<name>/` does **not** affect the running backend until you rebuild with `make backends/<name>`.
|
||||||
|
|
||||||
|
## Diagnosing Failures
|
||||||
|
|
||||||
|
### 1. Check the logs
|
||||||
|
|
||||||
|
Backend gRPC processes log to LocalAI's stdout/stderr. Look for lines tagged with the backend's model ID:
|
||||||
|
|
||||||
|
```
|
||||||
|
GRPC stderr id="trl-finetune-127.0.0.1:37335" line="..."
|
||||||
|
```
|
||||||
|
|
||||||
|
Common error patterns:
|
||||||
|
- **"Method not implemented"** — the backend is missing a gRPC method that the Go side calls. The model loader (`pkg/model/initializers.go`) always calls `LoadModel` after `Health`; fine-tuning backends must implement it even as a no-op stub.
|
||||||
|
- **Python import errors / `AttributeError`** — usually a dependency version mismatch (e.g. `pyarrow` removing `PyExtensionType`).
|
||||||
|
- **"failed to load backend"** — the gRPC process crashed or never started. Check stderr lines for the traceback.
|
||||||
|
|
||||||
|
### 2. Test the Python environment directly
|
||||||
|
|
||||||
|
You can run the installed venv's Python to check imports without starting the full server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
backends/<name>/venv/bin/python -c "import datasets; print(datasets.__version__)"
|
||||||
|
```
|
||||||
|
|
||||||
|
If `pip` is missing from the venv, bootstrap it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
backends/<name>/venv/bin/python -m ensurepip
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use `backends/<name>/venv/bin/python -m pip install ...` to test fixes in the installed venv before committing them to the source requirements.
|
||||||
|
|
||||||
|
### 3. Check upstream dependency constraints
|
||||||
|
|
||||||
|
When you hit a dependency conflict, check what the main library expects. For example, TRL's upstream `requirements.txt`:
|
||||||
|
|
||||||
|
```
|
||||||
|
https://github.com/huggingface/trl/blob/main/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Pin minimum versions in the backend's requirements files to match upstream.
|
||||||
|
|
||||||
|
## Common Fixes
|
||||||
|
|
||||||
|
### Missing gRPC methods
|
||||||
|
|
||||||
|
If the Go side calls a method the backend doesn't implement (e.g. `LoadModel`), add a no-op stub in `backend.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""No-op — actual loading happens elsewhere."""
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
```
|
||||||
|
|
||||||
|
The gRPC contract requires `LoadModel` to succeed for the model loader to return a usable client, even if the backend doesn't need upfront model loading.
|
||||||
|
|
||||||
|
### Dependency version conflicts
|
||||||
|
|
||||||
|
Python backends often break when a transitive dependency releases a breaking change (e.g. `pyarrow` removing `PyExtensionType`). Steps:
|
||||||
|
|
||||||
|
1. Identify the broken import in the logs
|
||||||
|
2. Test in the installed venv: `backends/<name>/venv/bin/python -c "import <module>"`
|
||||||
|
3. Check upstream requirements for version constraints
|
||||||
|
4. Update **all** requirements files in `backend/python/<name>/`:
|
||||||
|
- `requirements.txt` — base deps (grpcio, protobuf)
|
||||||
|
- `requirements-cpu.txt` — CPU-specific (includes PyTorch CPU index)
|
||||||
|
- `requirements-cublas12.txt` — CUDA 12
|
||||||
|
- `requirements-cublas13.txt` — CUDA 13
|
||||||
|
5. Rebuild: `make backends/<name>`
|
||||||
|
|
||||||
|
### PyTorch index conflicts (uv resolver)
|
||||||
|
|
||||||
|
The Docker build uses `uv` for pip installs. When `--extra-index-url` points to the PyTorch wheel index, `uv` may refuse to fetch packages like `requests` from PyPI if it finds a different version on the PyTorch index first. Fix this by adding `--index-strategy=unsafe-first-match` to `install.sh`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
|
installRequirements
|
||||||
|
```
|
||||||
|
|
||||||
|
Most Python backends already do this — check `backend/python/transformers/install.sh` or similar for reference.
|
||||||
|
|
||||||
|
## Rebuilding
|
||||||
|
|
||||||
|
### Rebuild a single backend
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make backends/<name>
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs the Docker build (`Dockerfile.python`), exports the image to `backend-images/<name>.tar`, and installs it into `backends/<name>/`. It also rebuilds the `local-ai` Go binary (without extra tags).
|
||||||
|
|
||||||
|
**Important**: If you were previously running with `GO_TAGS=auth`, the `make backends/<name>` step will overwrite your binary without that tag. Rebuild the Go binary afterward:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GO_TAGS=auth make build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rebuild and restart
|
||||||
|
|
||||||
|
After rebuilding a backend, you must restart LocalAI for it to pick up the new backend files. The backend gRPC process is spawned on demand when the model is first loaded.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Kill existing process
|
||||||
|
kill <pid>
|
||||||
|
|
||||||
|
# Restart
|
||||||
|
./local-ai run --debug [your flags]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quick iteration (skip Docker rebuild)
|
||||||
|
|
||||||
|
For fast iteration on a Python backend's `backend.py` without a full Docker rebuild, you can edit the installed copy directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Edit the installed copy
|
||||||
|
vim backends/<name>/backend.py
|
||||||
|
|
||||||
|
# Restart LocalAI to respawn the gRPC process
|
||||||
|
```
|
||||||
|
|
||||||
|
This is useful for testing but **does not persist** — the next `make backends/<name>` will overwrite it. Always commit fixes to the source in `backend/python/<name>/`.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
After fixing and rebuilding:
|
||||||
|
|
||||||
|
1. Start LocalAI and confirm the backend registers: look for `Registering backend name="<name>"` in the logs
|
||||||
|
2. Trigger the operation that failed (e.g. start a fine-tuning job)
|
||||||
|
3. Watch the GRPC stderr/stdout lines for the backend's model ID
|
||||||
|
4. Confirm no errors in the traceback
|
||||||
39
.github/workflows/backend.yml
vendored
39
.github/workflows/backend.yml
vendored
@@ -118,6 +118,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
cuda-minor-version: ""
|
cuda-minor-version: ""
|
||||||
@@ -366,6 +379,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "8"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-12-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "12"
|
cuda-major-version: "12"
|
||||||
cuda-minor-version: "8"
|
cuda-minor-version: "8"
|
||||||
@@ -757,6 +783,19 @@ jobs:
|
|||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./"
|
context: "./"
|
||||||
ubuntu-version: '2404'
|
ubuntu-version: '2404'
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "13"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-13-trl'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:24.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "trl"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./"
|
||||||
|
ubuntu-version: '2404'
|
||||||
- build-type: 'l4t'
|
- build-type: 'l4t'
|
||||||
cuda-major-version: "13"
|
cuda-major-version: "13"
|
||||||
cuda-minor-version: "0"
|
cuda-minor-version: "0"
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ This file is an index to detailed topic guides in the `.agents/` directory. Read
|
|||||||
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
| [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing |
|
||||||
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
| [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI |
|
||||||
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
| [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control |
|
||||||
|
| [.agents/debugging-backends.md](.agents/debugging-backends.md) | Debugging runtime backend failures, dependency conflicts, rebuilding backends |
|
||||||
|
|
||||||
## Quick Reference
|
## Quick Reference
|
||||||
|
|
||||||
|
|||||||
8
Makefile
8
Makefile
@@ -1,5 +1,5 @@
|
|||||||
# Disable parallel execution for backend builds
|
# Disable parallel execution for backend builds
|
||||||
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus
|
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus backends/trl
|
||||||
|
|
||||||
GOCMD=go
|
GOCMD=go
|
||||||
GOTEST=$(GOCMD) test
|
GOTEST=$(GOCMD) test
|
||||||
@@ -421,6 +421,7 @@ prepare-test-extra: protogen-python
|
|||||||
$(MAKE) -C backend/python/voxcpm
|
$(MAKE) -C backend/python/voxcpm
|
||||||
$(MAKE) -C backend/python/whisperx
|
$(MAKE) -C backend/python/whisperx
|
||||||
$(MAKE) -C backend/python/ace-step
|
$(MAKE) -C backend/python/ace-step
|
||||||
|
$(MAKE) -C backend/python/trl
|
||||||
|
|
||||||
test-extra: prepare-test-extra
|
test-extra: prepare-test-extra
|
||||||
$(MAKE) -C backend/python/transformers test
|
$(MAKE) -C backend/python/transformers test
|
||||||
@@ -440,6 +441,7 @@ test-extra: prepare-test-extra
|
|||||||
$(MAKE) -C backend/python/voxcpm test
|
$(MAKE) -C backend/python/voxcpm test
|
||||||
$(MAKE) -C backend/python/whisperx test
|
$(MAKE) -C backend/python/whisperx test
|
||||||
$(MAKE) -C backend/python/ace-step test
|
$(MAKE) -C backend/python/ace-step test
|
||||||
|
$(MAKE) -C backend/python/trl test
|
||||||
|
|
||||||
DOCKER_IMAGE?=local-ai
|
DOCKER_IMAGE?=local-ai
|
||||||
IMAGE_TYPE?=core
|
IMAGE_TYPE?=core
|
||||||
@@ -572,6 +574,7 @@ BACKEND_VOXCPM = voxcpm|python|.|false|true
|
|||||||
BACKEND_WHISPERX = whisperx|python|.|false|true
|
BACKEND_WHISPERX = whisperx|python|.|false|true
|
||||||
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
BACKEND_ACE_STEP = ace-step|python|.|false|true
|
||||||
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
|
||||||
|
BACKEND_TRL = trl|python|.|false|true
|
||||||
|
|
||||||
# Helper function to build docker image for a backend
|
# Helper function to build docker image for a backend
|
||||||
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
|
||||||
@@ -629,12 +632,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
|
|||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_ACESTEP_CPP)))
|
||||||
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
|
||||||
|
$(eval $(call generate-docker-build-target,$(BACKEND_TRL)))
|
||||||
|
|
||||||
# Pattern rule for docker-save targets
|
# Pattern rule for docker-save targets
|
||||||
docker-save-%: backend-images
|
docker-save-%: backend-images
|
||||||
docker save local-ai-backend:$* -o backend-images/$*.tar
|
docker save local-ai-backend:$* -o backend-images/$*.tar
|
||||||
|
|
||||||
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed
|
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed docker-build-trl
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
### Mock Backend for E2E Tests
|
### Mock Backend for E2E Tests
|
||||||
|
|||||||
@@ -39,6 +39,13 @@ service Backend {
|
|||||||
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
rpc AudioDecode(AudioDecodeRequest) returns (AudioDecodeResult) {}
|
||||||
|
|
||||||
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
rpc ModelMetadata(ModelOptions) returns (ModelMetadataResponse) {}
|
||||||
|
|
||||||
|
// Fine-tuning RPCs
|
||||||
|
rpc StartFineTune(FineTuneRequest) returns (FineTuneJobResult) {}
|
||||||
|
rpc FineTuneProgress(FineTuneProgressRequest) returns (stream FineTuneProgressUpdate) {}
|
||||||
|
rpc StopFineTune(FineTuneStopRequest) returns (Result) {}
|
||||||
|
rpc ListCheckpoints(ListCheckpointsRequest) returns (ListCheckpointsResponse) {}
|
||||||
|
rpc ExportModel(ExportModelRequest) returns (Result) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define the empty request
|
// Define the empty request
|
||||||
@@ -528,3 +535,105 @@ message ModelMetadataResponse {
|
|||||||
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
|
||||||
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fine-tuning messages
|
||||||
|
|
||||||
|
message FineTuneRequest {
|
||||||
|
// Model identification
|
||||||
|
string model = 1; // HF model name or local path
|
||||||
|
string training_type = 2; // "lora", "loha", "lokr", "full" — what parameters to train
|
||||||
|
string training_method = 3; // "sft", "dpo", "grpo", "rloo", "reward", "kto", "orpo", "network_training"
|
||||||
|
|
||||||
|
// Adapter config (universal across LoRA/LoHa/LoKr for LLM + diffusion)
|
||||||
|
int32 adapter_rank = 10; // LoRA rank (r), default 16
|
||||||
|
int32 adapter_alpha = 11; // scaling factor, default 16
|
||||||
|
float adapter_dropout = 12; // default 0.0
|
||||||
|
repeated string target_modules = 13; // layer names to adapt
|
||||||
|
|
||||||
|
// Universal training hyperparameters
|
||||||
|
float learning_rate = 20; // default 2e-4
|
||||||
|
int32 num_epochs = 21; // default 3
|
||||||
|
int32 batch_size = 22; // default 2
|
||||||
|
int32 gradient_accumulation_steps = 23; // default 4
|
||||||
|
int32 warmup_steps = 24; // default 5
|
||||||
|
int32 max_steps = 25; // 0 = use epochs
|
||||||
|
int32 save_steps = 26; // 0 = only save final
|
||||||
|
float weight_decay = 27; // default 0.01
|
||||||
|
bool gradient_checkpointing = 28;
|
||||||
|
string optimizer = 29; // adamw_8bit, adamw, sgd, adafactor, prodigy
|
||||||
|
int32 seed = 30; // default 3407
|
||||||
|
string mixed_precision = 31; // fp16, bf16, fp8, no
|
||||||
|
|
||||||
|
// Dataset
|
||||||
|
string dataset_source = 40; // HF dataset ID, local file/dir path
|
||||||
|
string dataset_split = 41; // train, test, etc.
|
||||||
|
|
||||||
|
// Output
|
||||||
|
string output_dir = 50;
|
||||||
|
string job_id = 51; // client-assigned or auto-generated
|
||||||
|
|
||||||
|
// Resume training from a checkpoint
|
||||||
|
string resume_from_checkpoint = 55; // path to checkpoint dir to resume from
|
||||||
|
|
||||||
|
// Backend-specific AND method-specific extensibility
|
||||||
|
map<string, string> extra_options = 60;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneJobResult {
|
||||||
|
string job_id = 1;
|
||||||
|
bool success = 2;
|
||||||
|
string message = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneProgressRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneProgressUpdate {
|
||||||
|
string job_id = 1;
|
||||||
|
int32 current_step = 2;
|
||||||
|
int32 total_steps = 3;
|
||||||
|
float current_epoch = 4;
|
||||||
|
float total_epochs = 5;
|
||||||
|
float loss = 6;
|
||||||
|
float learning_rate = 7;
|
||||||
|
float grad_norm = 8;
|
||||||
|
float eval_loss = 9;
|
||||||
|
float eta_seconds = 10;
|
||||||
|
float progress_percent = 11;
|
||||||
|
string status = 12; // queued, caching, loading_model, loading_dataset, training, saving, completed, failed, stopped
|
||||||
|
string message = 13;
|
||||||
|
string checkpoint_path = 14; // set when a checkpoint is saved
|
||||||
|
string sample_path = 15; // set when a sample is generated (video/image backends)
|
||||||
|
map<string, float> extra_metrics = 16; // method-specific metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
message FineTuneStopRequest {
|
||||||
|
string job_id = 1;
|
||||||
|
bool save_checkpoint = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCheckpointsRequest {
|
||||||
|
string output_dir = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCheckpointsResponse {
|
||||||
|
repeated CheckpointInfo checkpoints = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CheckpointInfo {
|
||||||
|
string path = 1;
|
||||||
|
int32 step = 2;
|
||||||
|
float epoch = 3;
|
||||||
|
float loss = 4;
|
||||||
|
string created_at = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExportModelRequest {
|
||||||
|
string checkpoint_path = 1;
|
||||||
|
string output_path = 2;
|
||||||
|
string export_format = 3; // lora, loha, lokr, merged_16bit, merged_4bit, gguf, diffusers
|
||||||
|
string quantization_method = 4; // for GGUF: q4_k_m, q5_k_m, q8_0, f16, etc.
|
||||||
|
string model = 5; // base model name (for merge operations)
|
||||||
|
map<string, string> extra_options = 6;
|
||||||
|
}
|
||||||
|
|||||||
@@ -3029,3 +3029,54 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
- localai/localai-backends:master-metal-darwin-arm64-voxtral
|
||||||
|
- &trl
|
||||||
|
name: "trl"
|
||||||
|
alias: "trl"
|
||||||
|
license: apache-2.0
|
||||||
|
description: |
|
||||||
|
HuggingFace TRL fine-tuning backend. Supports SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO training methods.
|
||||||
|
Works on CPU and GPU.
|
||||||
|
urls:
|
||||||
|
- https://github.com/huggingface/trl
|
||||||
|
tags:
|
||||||
|
- fine-tuning
|
||||||
|
- LLM
|
||||||
|
- CPU
|
||||||
|
- GPU
|
||||||
|
- CUDA
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-trl"
|
||||||
|
nvidia: "cuda12-trl"
|
||||||
|
nvidia-cuda-12: "cuda12-trl"
|
||||||
|
nvidia-cuda-13: "cuda13-trl"
|
||||||
|
## TRL backend images
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cpu-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cpu-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda12-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda12-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cublas-cuda12-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda12-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda12-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cublas-cuda12-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda13-trl"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cublas-cuda13-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cublas-cuda13-trl
|
||||||
|
- !!merge <<: *trl
|
||||||
|
name: "cuda13-trl-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cublas-cuda13-trl"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cublas-cuda13-trl
|
||||||
|
|||||||
26
backend/python/trl/Makefile
Normal file
26
backend/python/trl/Makefile
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
|
||||||
|
LLAMA_CPP_CONVERT_VERSION ?= master
|
||||||
|
|
||||||
|
.PHONY: trl
|
||||||
|
trl:
|
||||||
|
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
run: trl
|
||||||
|
@echo "Running trl..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "trl run."
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: trl
|
||||||
|
@echo "Testing trl..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "trl tested."
|
||||||
|
|
||||||
|
.PHONY: protogen-clean
|
||||||
|
protogen-clean:
|
||||||
|
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: protogen-clean
|
||||||
|
rm -rf venv __pycache__
|
||||||
860
backend/python/trl/backend.py
Normal file
860
backend/python/trl/backend.py
Normal file
@@ -0,0 +1,860 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
TRL fine-tuning backend for LocalAI.
|
||||||
|
|
||||||
|
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||||
|
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressCallback:
|
||||||
|
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
|
||||||
|
|
||||||
|
def __init__(self, job_id, progress_queue, total_epochs):
|
||||||
|
self.job_id = job_id
|
||||||
|
self.progress_queue = progress_queue
|
||||||
|
self.total_epochs = total_epochs
|
||||||
|
|
||||||
|
def get_callback(self):
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
parent = self
|
||||||
|
|
||||||
|
class _Callback(TrainerCallback):
|
||||||
|
def __init__(self):
|
||||||
|
self._train_start_time = None
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
self._train_start_time = time.time()
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if logs is None:
|
||||||
|
return
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
eta = 0.0
|
||||||
|
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
|
||||||
|
elapsed = time.time() - self._train_start_time
|
||||||
|
remaining_steps = total_steps - state.global_step
|
||||||
|
if state.global_step > 0:
|
||||||
|
eta = remaining_steps * (elapsed / state.global_step)
|
||||||
|
|
||||||
|
extra_metrics = {}
|
||||||
|
for k, v in logs.items():
|
||||||
|
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
|
||||||
|
extra_metrics[k] = float(v)
|
||||||
|
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(logs.get('epoch', 0)),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
loss=float(logs.get('loss', 0)),
|
||||||
|
learning_rate=float(logs.get('learning_rate', 0)),
|
||||||
|
grad_norm=float(logs.get('grad_norm', 0)),
|
||||||
|
eval_loss=float(logs.get('eval_loss', 0)),
|
||||||
|
eta_seconds=float(eta),
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
extra_metrics=extra_metrics,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_prediction_step(self, args, state, control, **kwargs):
|
||||||
|
"""Send periodic updates during evaluation so the UI doesn't freeze."""
|
||||||
|
if not hasattr(self, '_eval_update_counter'):
|
||||||
|
self._eval_update_counter = 0
|
||||||
|
self._eval_update_counter += 1
|
||||||
|
# Throttle: send an update every 10 prediction steps
|
||||||
|
if self._eval_update_counter % 10 != 0:
|
||||||
|
return
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(state.epoch or 0),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
message=f"Evaluating... (batch {self._eval_update_counter})",
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||||
|
"""Report eval results once evaluation is done."""
|
||||||
|
# Reset prediction counter for next eval round
|
||||||
|
self._eval_update_counter = 0
|
||||||
|
|
||||||
|
total_steps = state.max_steps if state.max_steps > 0 else 0
|
||||||
|
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
|
||||||
|
|
||||||
|
eval_loss = 0.0
|
||||||
|
extra_metrics = {}
|
||||||
|
if metrics:
|
||||||
|
eval_loss = float(metrics.get('eval_loss', 0))
|
||||||
|
for k, v in metrics.items():
|
||||||
|
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
|
||||||
|
extra_metrics[k] = float(v)
|
||||||
|
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
current_epoch=float(state.epoch or 0),
|
||||||
|
total_epochs=float(parent.total_epochs),
|
||||||
|
eval_loss=eval_loss,
|
||||||
|
progress_percent=float(progress),
|
||||||
|
status="training",
|
||||||
|
message=f"Evaluation complete at step {state.global_step}",
|
||||||
|
extra_metrics=extra_metrics,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
status="saving",
|
||||||
|
message=f"Checkpoint saved at step {state.global_step}",
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=parent.job_id,
|
||||||
|
current_step=state.global_step,
|
||||||
|
total_steps=state.max_steps,
|
||||||
|
progress_percent=100.0,
|
||||||
|
status="completed",
|
||||||
|
message="Training completed",
|
||||||
|
)
|
||||||
|
parent.progress_queue.put(update)
|
||||||
|
|
||||||
|
return _Callback()
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveJob:
|
||||||
|
"""Represents an active fine-tuning job."""
|
||||||
|
|
||||||
|
def __init__(self, job_id):
|
||||||
|
self.job_id = job_id
|
||||||
|
self.progress_queue = queue.Queue()
|
||||||
|
self.trainer = None
|
||||||
|
self.thread = None
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
self.error = None
|
||||||
|
self.completed = False
|
||||||
|
self.stopped = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_gated_repo_error(exc):
|
||||||
|
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||||
|
try:
|
||||||
|
from huggingface_hub.utils import GatedRepoError
|
||||||
|
if isinstance(exc, GatedRepoError):
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
msg = str(exc).lower()
|
||||||
|
if "gated repo" in msg or "access to model" in msg:
|
||||||
|
return True
|
||||||
|
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||||
|
if exc.response.status_code in (401, 403):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
def __init__(self):
|
||||||
|
self.active_job = None
|
||||||
|
|
||||||
|
def Health(self, request, context):
|
||||||
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
"""Accept LoadModel — actual model loading happens in StartFineTune."""
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
|
||||||
|
def StartFineTune(self, request, context):
|
||||||
|
if self.active_job is not None and not self.active_job.completed:
|
||||||
|
return backend_pb2.FineTuneJobResult(
|
||||||
|
job_id="",
|
||||||
|
success=False,
|
||||||
|
message="A fine-tuning job is already running",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_id = request.job_id if request.job_id else str(uuid.uuid4())
|
||||||
|
job = ActiveJob(job_id)
|
||||||
|
self.active_job = job
|
||||||
|
|
||||||
|
# Start training in background thread
|
||||||
|
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
|
||||||
|
job.thread = thread
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return backend_pb2.FineTuneJobResult(
|
||||||
|
job_id=job_id,
|
||||||
|
success=True,
|
||||||
|
message="Fine-tuning job started",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_training(self, request, job):
|
||||||
|
try:
|
||||||
|
self._do_training(request, job)
|
||||||
|
except Exception as e:
|
||||||
|
if _is_gated_repo_error(e):
|
||||||
|
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||||
|
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||||
|
else:
|
||||||
|
msg = f"Training failed: {e}"
|
||||||
|
job.error = msg
|
||||||
|
job.completed = True
|
||||||
|
update = backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id,
|
||||||
|
status="failed",
|
||||||
|
message=msg,
|
||||||
|
)
|
||||||
|
job.progress_queue.put(update)
|
||||||
|
# Send sentinel
|
||||||
|
job.progress_queue.put(None)
|
||||||
|
|
||||||
|
def _do_training(self, request, job):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from datasets import load_dataset, Dataset
|
||||||
|
|
||||||
|
extra = dict(request.extra_options)
|
||||||
|
training_method = request.training_method or "sft"
|
||||||
|
training_type = request.training_type or "lora"
|
||||||
|
|
||||||
|
# Send loading status
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
|
||||||
|
))
|
||||||
|
|
||||||
|
# Determine device and dtype
|
||||||
|
device_map = "auto" if torch.cuda.is_available() else "cpu"
|
||||||
|
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
|
||||||
|
|
||||||
|
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
|
||||||
|
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
|
||||||
|
if hf_token:
|
||||||
|
model_kwargs["token"] = hf_token
|
||||||
|
if extra.get("trust_remote_code", "false").lower() == "true":
|
||||||
|
model_kwargs["trust_remote_code"] = True
|
||||||
|
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
job.model = model
|
||||||
|
job.tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Apply LoRA if requested
|
||||||
|
if training_type == "lora":
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
|
||||||
|
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
|
||||||
|
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
|
||||||
|
|
||||||
|
target_modules = list(request.target_modules) if request.target_modules else None
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
r=lora_r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
target_modules=target_modules or "all-linear",
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
|
||||||
|
))
|
||||||
|
|
||||||
|
dataset_split = request.dataset_split or "train"
|
||||||
|
if os.path.exists(request.dataset_source):
|
||||||
|
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||||
|
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
|
||||||
|
elif request.dataset_source.endswith('.csv'):
|
||||||
|
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
|
||||||
|
else:
|
||||||
|
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||||
|
else:
|
||||||
|
dataset = load_dataset(request.dataset_source, split=dataset_split)
|
||||||
|
|
||||||
|
# Eval dataset setup
|
||||||
|
eval_dataset = None
|
||||||
|
eval_strategy = extra.get("eval_strategy", "steps")
|
||||||
|
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
|
||||||
|
|
||||||
|
if eval_strategy != "no":
|
||||||
|
eval_split = extra.get("eval_split")
|
||||||
|
eval_dataset_source = extra.get("eval_dataset_source")
|
||||||
|
if eval_split:
|
||||||
|
# Load a specific split as eval dataset
|
||||||
|
if os.path.exists(request.dataset_source):
|
||||||
|
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
|
||||||
|
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
|
||||||
|
elif request.dataset_source.endswith('.csv'):
|
||||||
|
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
|
||||||
|
else:
|
||||||
|
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||||
|
else:
|
||||||
|
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
|
||||||
|
elif eval_dataset_source:
|
||||||
|
# Load eval dataset from a separate source
|
||||||
|
eval_dataset = load_dataset(eval_dataset_source, split="train")
|
||||||
|
else:
|
||||||
|
# Auto-split from training set
|
||||||
|
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
|
||||||
|
split = dataset.train_test_split(test_size=eval_split_ratio)
|
||||||
|
dataset = split["train"]
|
||||||
|
eval_dataset = split["test"]
|
||||||
|
|
||||||
|
if eval_strategy == "no":
|
||||||
|
eval_dataset = None
|
||||||
|
|
||||||
|
# Training config
|
||||||
|
output_dir = request.output_dir or f"./output-{job.job_id}"
|
||||||
|
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
|
||||||
|
batch_size = request.batch_size if request.batch_size > 0 else 2
|
||||||
|
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
|
||||||
|
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
|
||||||
|
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
|
||||||
|
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
|
||||||
|
max_steps = request.max_steps if request.max_steps > 0 else -1
|
||||||
|
save_steps = request.save_steps if request.save_steps > 0 else 500
|
||||||
|
seed = request.seed if request.seed > 0 else 3407
|
||||||
|
optimizer = request.optimizer or "adamw_torch"
|
||||||
|
|
||||||
|
# Checkpoint save controls
|
||||||
|
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
|
||||||
|
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
|
||||||
|
|
||||||
|
# CPU vs GPU training args (can be overridden via extra_options)
|
||||||
|
use_cpu = not torch.cuda.is_available()
|
||||||
|
common_train_kwargs = {}
|
||||||
|
if use_cpu:
|
||||||
|
common_train_kwargs["use_cpu"] = True
|
||||||
|
common_train_kwargs["fp16"] = False
|
||||||
|
common_train_kwargs["bf16"] = False
|
||||||
|
common_train_kwargs["gradient_checkpointing"] = False
|
||||||
|
else:
|
||||||
|
common_train_kwargs["bf16"] = True
|
||||||
|
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
|
||||||
|
|
||||||
|
# Allow extra_options to override training kwargs
|
||||||
|
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
|
||||||
|
if flag in extra:
|
||||||
|
common_train_kwargs[flag] = extra[flag].lower() == "true"
|
||||||
|
|
||||||
|
# Create progress callback
|
||||||
|
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
|
||||||
|
|
||||||
|
# Build save kwargs (shared across all methods)
|
||||||
|
_save_kwargs = {}
|
||||||
|
if save_strategy == "steps" and save_steps > 0:
|
||||||
|
_save_kwargs["save_steps"] = save_steps
|
||||||
|
_save_kwargs["save_strategy"] = "steps"
|
||||||
|
elif save_strategy == "epoch":
|
||||||
|
_save_kwargs["save_strategy"] = "epoch"
|
||||||
|
elif save_strategy == "no":
|
||||||
|
_save_kwargs["save_strategy"] = "no"
|
||||||
|
else:
|
||||||
|
_save_kwargs["save_steps"] = save_steps
|
||||||
|
_save_kwargs["save_strategy"] = "steps"
|
||||||
|
if save_total_limit:
|
||||||
|
_save_kwargs["save_total_limit"] = save_total_limit
|
||||||
|
|
||||||
|
# Eval kwargs
|
||||||
|
_eval_kwargs = {}
|
||||||
|
if eval_dataset is not None:
|
||||||
|
_eval_kwargs["eval_strategy"] = eval_strategy
|
||||||
|
_eval_kwargs["eval_steps"] = eval_steps
|
||||||
|
|
||||||
|
# Common training arguments shared by all methods
|
||||||
|
_common_args = dict(
|
||||||
|
output_dir=output_dir,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
per_device_train_batch_size=batch_size,
|
||||||
|
learning_rate=lr,
|
||||||
|
gradient_accumulation_steps=grad_accum,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
max_steps=max_steps,
|
||||||
|
seed=seed,
|
||||||
|
optim=optimizer,
|
||||||
|
logging_steps=1,
|
||||||
|
report_to="none",
|
||||||
|
**_save_kwargs,
|
||||||
|
**common_train_kwargs,
|
||||||
|
**_eval_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select trainer based on training method
|
||||||
|
if training_method == "sft":
|
||||||
|
from trl import SFTTrainer, SFTConfig
|
||||||
|
|
||||||
|
max_length = int(extra.get("max_seq_length", "512"))
|
||||||
|
packing = extra.get("packing", "false").lower() == "true"
|
||||||
|
|
||||||
|
training_args = SFTConfig(
|
||||||
|
max_length=max_length,
|
||||||
|
packing=packing,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "dpo":
|
||||||
|
from trl import DPOTrainer, DPOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
loss_type = extra.get("loss_type", "sigmoid")
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = DPOConfig(
|
||||||
|
beta=beta,
|
||||||
|
loss_type=loss_type,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = DPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "grpo":
|
||||||
|
from trl import GRPOTrainer, GRPOConfig
|
||||||
|
|
||||||
|
num_generations = int(extra.get("num_generations", "4"))
|
||||||
|
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||||
|
|
||||||
|
training_args = GRPOConfig(
|
||||||
|
num_generations=num_generations,
|
||||||
|
max_completion_length=max_completion_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO requires reward functions passed via extra_options as a JSON list
|
||||||
|
from reward_functions import build_reward_functions
|
||||||
|
|
||||||
|
reward_funcs = []
|
||||||
|
if extra.get("reward_funcs"):
|
||||||
|
reward_funcs = build_reward_functions(extra["reward_funcs"])
|
||||||
|
|
||||||
|
if not reward_funcs:
|
||||||
|
raise ValueError(
|
||||||
|
"GRPO requires at least one reward function. "
|
||||||
|
"Specify reward_functions in the request or "
|
||||||
|
"reward_funcs in extra_options."
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = GRPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
reward_funcs=reward_funcs,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "orpo":
|
||||||
|
from trl import ORPOTrainer, ORPOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = ORPOConfig(
|
||||||
|
beta=beta,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = ORPOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "kto":
|
||||||
|
from trl import KTOTrainer, KTOConfig
|
||||||
|
|
||||||
|
beta = float(extra.get("beta", "0.1"))
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = KTOConfig(
|
||||||
|
beta=beta,
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = KTOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "rloo":
|
||||||
|
from trl import RLOOTrainer, RLOOConfig
|
||||||
|
|
||||||
|
num_generations = int(extra.get("num_generations", "4"))
|
||||||
|
max_completion_length = int(extra.get("max_completion_length", "256"))
|
||||||
|
|
||||||
|
training_args = RLOOConfig(
|
||||||
|
num_generations=num_generations,
|
||||||
|
max_new_tokens=max_completion_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = RLOOTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
elif training_method == "reward":
|
||||||
|
from trl import RewardTrainer, RewardConfig
|
||||||
|
|
||||||
|
max_length = int(extra.get("max_length", "512"))
|
||||||
|
|
||||||
|
training_args = RewardConfig(
|
||||||
|
max_length=max_length,
|
||||||
|
**_common_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
processing_class=tokenizer,
|
||||||
|
callbacks=[progress_cb.get_callback()],
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported training method: {training_method}. "
|
||||||
|
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
|
||||||
|
|
||||||
|
job.trainer = trainer
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
|
||||||
|
job_id=job.job_id, status="training", message="Training started",
|
||||||
|
))
|
||||||
|
|
||||||
|
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
|
||||||
|
trainer.train(resume_from_checkpoint=resume_ckpt)
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
if tokenizer:
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
job.completed = True
|
||||||
|
# Sentinel to signal stream end
|
||||||
|
job.progress_queue.put(None)
|
||||||
|
|
||||||
|
def FineTuneProgress(self, request, context):
|
||||||
|
if self.active_job is None or self.active_job.job_id != request.job_id:
|
||||||
|
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||||
|
context.set_details(f"Job {request.job_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
job = self.active_job
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
update = job.progress_queue.get(timeout=1.0)
|
||||||
|
if update is None:
|
||||||
|
break
|
||||||
|
yield update
|
||||||
|
if update.status in ("completed", "failed", "stopped"):
|
||||||
|
break
|
||||||
|
except queue.Empty:
|
||||||
|
if job.completed or job.stopped:
|
||||||
|
break
|
||||||
|
if not context.is_active():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
def StopFineTune(self, request, context):
|
||||||
|
# Stopping is handled by killing the process from Go via ShutdownModel.
|
||||||
|
return backend_pb2.Result(success=True, message="OK")
|
||||||
|
|
||||||
|
def ListCheckpoints(self, request, context):
|
||||||
|
output_dir = request.output_dir
|
||||||
|
if not os.path.isdir(output_dir):
|
||||||
|
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
|
||||||
|
|
||||||
|
checkpoints = []
|
||||||
|
for entry in sorted(os.listdir(output_dir)):
|
||||||
|
if entry.startswith("checkpoint-"):
|
||||||
|
ckpt_path = os.path.join(output_dir, entry)
|
||||||
|
if not os.path.isdir(ckpt_path):
|
||||||
|
continue
|
||||||
|
step = 0
|
||||||
|
try:
|
||||||
|
step = int(entry.split("-")[1])
|
||||||
|
except (IndexError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to read trainer_state.json for metadata
|
||||||
|
loss = 0.0
|
||||||
|
epoch = 0.0
|
||||||
|
state_file = os.path.join(ckpt_path, "trainer_state.json")
|
||||||
|
if os.path.exists(state_file):
|
||||||
|
try:
|
||||||
|
with open(state_file) as f:
|
||||||
|
state = json.load(f)
|
||||||
|
if state.get("log_history"):
|
||||||
|
last_log = state["log_history"][-1]
|
||||||
|
loss = last_log.get("loss", 0.0)
|
||||||
|
epoch = last_log.get("epoch", 0.0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
created_at = time.strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ",
|
||||||
|
time.gmtime(os.path.getmtime(ckpt_path)),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints.append(backend_pb2.CheckpointInfo(
|
||||||
|
path=ckpt_path,
|
||||||
|
step=step,
|
||||||
|
epoch=float(epoch),
|
||||||
|
loss=float(loss),
|
||||||
|
created_at=created_at,
|
||||||
|
))
|
||||||
|
|
||||||
|
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
|
||||||
|
|
||||||
|
def ExportModel(self, request, context):
|
||||||
|
export_format = request.export_format or "lora"
|
||||||
|
output_path = request.output_path
|
||||||
|
checkpoint_path = request.checkpoint_path
|
||||||
|
|
||||||
|
# Extract HF token for gated model access
|
||||||
|
extra = dict(request.extra_options) if request.extra_options else {}
|
||||||
|
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
|
||||||
|
|
||||||
|
if not checkpoint_path or not os.path.isdir(checkpoint_path):
|
||||||
|
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if export_format == "lora":
|
||||||
|
# Just copy the adapter files
|
||||||
|
import shutil
|
||||||
|
for f in os.listdir(checkpoint_path):
|
||||||
|
src = os.path.join(checkpoint_path, f)
|
||||||
|
dst = os.path.join(output_path, f)
|
||||||
|
if os.path.isfile(src):
|
||||||
|
shutil.copy2(src, dst)
|
||||||
|
|
||||||
|
elif export_format in ("merged_16bit", "merged_4bit"):
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
base_model_name = request.model
|
||||||
|
if not base_model_name:
|
||||||
|
return backend_pb2.Result(success=False, message="Base model name required for merge export")
|
||||||
|
|
||||||
|
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
|
||||||
|
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||||
|
merged = model.merge_and_unload()
|
||||||
|
merged.save_pretrained(output_path)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||||
|
tokenizer.save_pretrained(output_path)
|
||||||
|
|
||||||
|
elif export_format == "gguf":
|
||||||
|
import torch
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
base_model_name = request.model
|
||||||
|
if not base_model_name:
|
||||||
|
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
|
||||||
|
|
||||||
|
# Step 1: Merge LoRA into base model
|
||||||
|
merge_dir = os.path.join(output_path, "_hf_merged")
|
||||||
|
os.makedirs(merge_dir, exist_ok=True)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
|
||||||
|
model = PeftModel.from_pretrained(base_model, checkpoint_path)
|
||||||
|
merged = model.merge_and_unload()
|
||||||
|
merged.save_pretrained(merge_dir)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
|
||||||
|
tokenizer.save_pretrained(merge_dir)
|
||||||
|
|
||||||
|
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
|
||||||
|
# Gemma models need this file for GGUF conversion to use the
|
||||||
|
# SentencePiece path; without it, the script falls back to BPE
|
||||||
|
# handling which fails on unrecognized pre-tokenizer hashes.
|
||||||
|
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
|
||||||
|
if not os.path.exists(sp_model_path):
|
||||||
|
sp_copied = False
|
||||||
|
# Method 1: Load the slow tokenizer which keeps the SP model file
|
||||||
|
try:
|
||||||
|
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
|
||||||
|
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
|
||||||
|
sp_copied = True
|
||||||
|
print(f"Copied tokenizer.model from slow tokenizer cache")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Slow tokenizer method failed: {e}")
|
||||||
|
# Method 2: Download from HF hub
|
||||||
|
if not sp_copied:
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy2(cached_sp, sp_model_path)
|
||||||
|
sp_copied = True
|
||||||
|
print(f"Copied tokenizer.model from HF hub")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"HF hub download method failed: {e}")
|
||||||
|
if not sp_copied:
|
||||||
|
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
|
||||||
|
"GGUF conversion may fail for SentencePiece models.")
|
||||||
|
|
||||||
|
# Free GPU memory before conversion
|
||||||
|
del merged, model, base_model
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
|
||||||
|
quant = request.quantization_method or "auto"
|
||||||
|
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
|
||||||
|
outtype = outtype_map.get(quant, "f16")
|
||||||
|
|
||||||
|
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
|
||||||
|
gguf_path = os.path.join(output_path, gguf_filename)
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
|
||||||
|
if not os.path.exists(convert_script):
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
|
||||||
|
|
||||||
|
# Log merge_dir contents for debugging conversion issues
|
||||||
|
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
|
||||||
|
print(f"Merge dir contents: {merge_files}", flush=True)
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["NO_LOCAL_GGUF"] = "1"
|
||||||
|
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
|
||||||
|
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
|
||||||
|
if conv_result.returncode != 0:
|
||||||
|
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message=f"GGUF conversion failed: {diag}")
|
||||||
|
|
||||||
|
# Clean up intermediate merged model
|
||||||
|
shutil.rmtree(merge_dir, ignore_errors=True)
|
||||||
|
else:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if _is_gated_repo_error(e):
|
||||||
|
return backend_pb2.Result(success=False,
|
||||||
|
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
|
||||||
|
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
|
||||||
|
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
|
||||||
|
|
||||||
|
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(
|
||||||
|
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||||
|
options=[
|
||||||
|
('grpc.max_message_length', 50 * 1024 * 1024),
|
||||||
|
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
||||||
|
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
|
||||||
|
|
||||||
|
# Handle graceful shutdown
|
||||||
|
def stop(signum, frame):
|
||||||
|
server.stop(0)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, stop)
|
||||||
|
signal.signal(signal.SIGINT, stop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
|
||||||
|
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
|
||||||
|
args = parser.parse_args()
|
||||||
|
serve(args.addr)
|
||||||
37
backend/python/trl/install.sh
Normal file
37
backend/python/trl/install.sh
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
|
installRequirements
|
||||||
|
|
||||||
|
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
|
||||||
|
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
|
||||||
|
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
|
||||||
|
if [ ! -f "${CONVERT_SCRIPT}" ]; then
|
||||||
|
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
curl -L --fail --retry 3 \
|
||||||
|
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
|
||||||
|
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install gguf package from the same llama.cpp commit to keep them in sync
|
||||||
|
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
|
||||||
|
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
|
||||||
|
if [ "x${USE_PIP:-}" == "xtrue" ]; then
|
||||||
|
pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
else
|
||||||
|
uv pip install "${GGUF_PIP_SPEC}" || {
|
||||||
|
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
|
||||||
|
uv pip install "gguf>=0.16.0"
|
||||||
|
}
|
||||||
|
fi
|
||||||
9
backend/python/trl/requirements-cpu.txt
Normal file
9
backend/python/trl/requirements-cpu.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
9
backend/python/trl/requirements-cublas12.txt
Normal file
9
backend/python/trl/requirements-cublas12.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
|
bitsandbytes
|
||||||
9
backend/python/trl/requirements-cublas13.txt
Normal file
9
backend/python/trl/requirements-cublas13.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
torch==2.10.0
|
||||||
|
trl
|
||||||
|
peft
|
||||||
|
datasets>=3.0.0
|
||||||
|
transformers>=4.56.2
|
||||||
|
accelerate>=1.4.0
|
||||||
|
huggingface-hub>=1.3.0
|
||||||
|
sentencepiece
|
||||||
|
bitsandbytes
|
||||||
3
backend/python/trl/requirements.txt
Normal file
3
backend/python/trl/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
grpcio==1.78.1
|
||||||
|
protobuf
|
||||||
|
certifi
|
||||||
236
backend/python/trl/reward_functions.py
Normal file
236
backend/python/trl/reward_functions.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
Built-in reward functions and inline function compiler for GRPO training.
|
||||||
|
|
||||||
|
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
import string
|
||||||
|
import functools
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Built-in reward functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def format_reward(completions, **kwargs):
|
||||||
|
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
|
||||||
|
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
|
||||||
|
return [1.0 if pattern.search(c) else 0.0 for c in completions]
|
||||||
|
|
||||||
|
|
||||||
|
def reasoning_accuracy_reward(completions, **kwargs):
|
||||||
|
"""Extracts <answer>...</answer> content and compares to the expected answer."""
|
||||||
|
answers = kwargs.get("answer", [])
|
||||||
|
if not answers:
|
||||||
|
return [0.0] * len(completions)
|
||||||
|
|
||||||
|
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
|
||||||
|
scores = []
|
||||||
|
for i, c in enumerate(completions):
|
||||||
|
expected = answers[i] if i < len(answers) else ""
|
||||||
|
match = pattern.search(c)
|
||||||
|
if match:
|
||||||
|
extracted = match.group(1).strip()
|
||||||
|
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
|
||||||
|
else:
|
||||||
|
scores.append(0.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def length_reward(completions, target_length=200, **kwargs):
|
||||||
|
"""Score based on proximity to target_length. Returns [0, 1]."""
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
length = len(c)
|
||||||
|
if target_length <= 0:
|
||||||
|
scores.append(0.0)
|
||||||
|
else:
|
||||||
|
diff = abs(length - target_length) / target_length
|
||||||
|
scores.append(max(0.0, 1.0 - diff))
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def xml_tag_reward(completions, **kwargs):
|
||||||
|
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
|
||||||
|
tags = ["think", "answer"]
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
tag_score = 0.0
|
||||||
|
for tag in tags:
|
||||||
|
if f"<{tag}>" in c and f"</{tag}>" in c:
|
||||||
|
tag_score += 0.5
|
||||||
|
scores.append(min(tag_score, 1.0))
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def no_repetition_reward(completions, n=4, **kwargs):
|
||||||
|
"""Penalizes n-gram repetition. Returns [0, 1]."""
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
words = c.split()
|
||||||
|
if len(words) < n:
|
||||||
|
scores.append(1.0)
|
||||||
|
continue
|
||||||
|
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
|
||||||
|
unique = len(set(ngrams))
|
||||||
|
total = len(ngrams)
|
||||||
|
scores.append(unique / total if total > 0 else 1.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def code_execution_reward(completions, **kwargs):
|
||||||
|
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
|
||||||
|
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
|
||||||
|
scores = []
|
||||||
|
for c in completions:
|
||||||
|
match = pattern.search(c)
|
||||||
|
if not match:
|
||||||
|
scores.append(0.0)
|
||||||
|
continue
|
||||||
|
code = match.group(1)
|
||||||
|
try:
|
||||||
|
compile(code, "<inline>", "exec")
|
||||||
|
scores.append(1.0)
|
||||||
|
except SyntaxError:
|
||||||
|
scores.append(0.0)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
BUILTIN_REGISTRY = {
|
||||||
|
"format_reward": format_reward,
|
||||||
|
"reasoning_accuracy_reward": reasoning_accuracy_reward,
|
||||||
|
"length_reward": length_reward,
|
||||||
|
"xml_tag_reward": xml_tag_reward,
|
||||||
|
"no_repetition_reward": no_repetition_reward,
|
||||||
|
"code_execution_reward": code_execution_reward,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inline function compiler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SAFE_BUILTINS = {
|
||||||
|
"len": len, "int": int, "float": float, "str": str, "bool": bool,
|
||||||
|
"list": list, "dict": dict, "tuple": tuple, "set": set,
|
||||||
|
"range": range, "enumerate": enumerate, "zip": zip,
|
||||||
|
"map": map, "filter": filter, "sorted": sorted,
|
||||||
|
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
|
||||||
|
"any": any, "all": all, "isinstance": isinstance, "type": type,
|
||||||
|
"print": print, "True": True, "False": False, "None": None,
|
||||||
|
"ValueError": ValueError, "TypeError": TypeError,
|
||||||
|
"KeyError": KeyError, "IndexError": IndexError,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compile_inline_reward(name, code):
|
||||||
|
"""Compile user-provided code into a reward function.
|
||||||
|
|
||||||
|
The code should be the body of a function that receives
|
||||||
|
`completions` (list[str]) and `**kwargs`, and returns list[float].
|
||||||
|
|
||||||
|
Available modules: re, math, json, string.
|
||||||
|
"""
|
||||||
|
func_source = (
|
||||||
|
f"def _user_reward_{name}(completions, **kwargs):\n"
|
||||||
|
+ "\n".join(f" {line}" for line in code.splitlines())
|
||||||
|
)
|
||||||
|
|
||||||
|
restricted_globals = {
|
||||||
|
"__builtins__": _SAFE_BUILTINS,
|
||||||
|
"re": re,
|
||||||
|
"math": math,
|
||||||
|
"json": json,
|
||||||
|
"string": string,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
|
||||||
|
except SyntaxError as e:
|
||||||
|
raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
|
||||||
|
|
||||||
|
exec(compiled, restricted_globals)
|
||||||
|
func = restricted_globals[f"_user_reward_{name}"]
|
||||||
|
|
||||||
|
# Validate with a quick smoke test
|
||||||
|
try:
|
||||||
|
result = func(["test"], answer=["test"])
|
||||||
|
if not isinstance(result, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "must return a list" in str(e):
|
||||||
|
raise
|
||||||
|
# Other errors during smoke test are acceptable (e.g. missing kwargs)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dispatcher
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_reward_functions(specs_json):
|
||||||
|
"""Parse a JSON list of reward function specs and return a list of callables.
|
||||||
|
|
||||||
|
Each spec is a dict with:
|
||||||
|
- type: "builtin" or "inline"
|
||||||
|
- name: function name
|
||||||
|
- code: (inline only) Python function body
|
||||||
|
- params: (optional) dict of string params applied via functools.partial
|
||||||
|
"""
|
||||||
|
if isinstance(specs_json, str):
|
||||||
|
specs = json.loads(specs_json)
|
||||||
|
else:
|
||||||
|
specs = specs_json
|
||||||
|
|
||||||
|
if not isinstance(specs, list):
|
||||||
|
raise ValueError("reward_funcs must be a JSON array of reward function specs")
|
||||||
|
|
||||||
|
reward_funcs = []
|
||||||
|
for spec in specs:
|
||||||
|
spec_type = spec.get("type", "builtin")
|
||||||
|
name = spec.get("name", "")
|
||||||
|
params = spec.get("params", {})
|
||||||
|
|
||||||
|
if spec_type == "builtin":
|
||||||
|
if name not in BUILTIN_REGISTRY:
|
||||||
|
available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown builtin reward function '{name}'. Available: {available}"
|
||||||
|
)
|
||||||
|
func = BUILTIN_REGISTRY[name]
|
||||||
|
if params:
|
||||||
|
# Convert string params to appropriate types
|
||||||
|
typed_params = {}
|
||||||
|
for k, v in params.items():
|
||||||
|
try:
|
||||||
|
typed_params[k] = int(v)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
try:
|
||||||
|
typed_params[k] = float(v)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
typed_params[k] = v
|
||||||
|
func = functools.partial(func, **typed_params)
|
||||||
|
reward_funcs.append(func)
|
||||||
|
|
||||||
|
elif spec_type == "inline":
|
||||||
|
code = spec.get("code", "")
|
||||||
|
if not code.strip():
|
||||||
|
raise ValueError(f"Inline reward function '{name}' has no code")
|
||||||
|
func = compile_inline_reward(name, code)
|
||||||
|
reward_funcs.append(func)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
|
||||||
|
|
||||||
|
return reward_funcs
|
||||||
10
backend/python/trl/run.sh
Normal file
10
backend/python/trl/run.sh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
startBackend $@
|
||||||
58
backend/python/trl/test.py
Normal file
58
backend/python/trl/test.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""
|
||||||
|
Test script for the TRL fine-tuning gRPC backend.
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendServicer(unittest.TestCase):
|
||||||
|
"""Tests for the TRL fine-tuning gRPC service."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.service = subprocess.Popen(
|
||||||
|
["python3", "backend.py", "--addr", "localhost:50051"]
|
||||||
|
)
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.service.kill()
|
||||||
|
self.service.wait()
|
||||||
|
|
||||||
|
def test_server_startup(self):
|
||||||
|
"""Test that the server starts and responds to health checks."""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b'OK')
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Server failed to start")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_list_checkpoints_empty(self):
|
||||||
|
"""Test listing checkpoints on a non-existent directory."""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.ListCheckpoints(
|
||||||
|
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
|
||||||
|
)
|
||||||
|
self.assertEqual(len(response.checkpoints), 0)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("ListCheckpoints service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
11
backend/python/trl/test.sh
Normal file
11
backend/python/trl/test.sh
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
runUnittests
|
||||||
@@ -121,6 +121,9 @@ type RunCMD struct {
|
|||||||
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"`
|
||||||
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"`
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
EnableFineTuning bool `env:"LOCALAI_ENABLE_FINETUNING" default:"false" help:"Enable fine-tuning support" group:"finetuning"`
|
||||||
|
|
||||||
// Authentication
|
// Authentication
|
||||||
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"`
|
||||||
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
|
AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"`
|
||||||
@@ -326,6 +329,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
|
opts = append(opts, config.WithAgentHubURL(r.AgentHubURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
if r.EnableFineTuning {
|
||||||
|
opts = append(opts, config.EnableFineTuning)
|
||||||
|
}
|
||||||
|
|
||||||
// Authentication
|
// Authentication
|
||||||
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
|
authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != ""
|
||||||
if authEnabled {
|
if authEnabled {
|
||||||
|
|||||||
@@ -97,6 +97,9 @@ type ApplicationConfig struct {
|
|||||||
// Agent Pool (LocalAGI integration)
|
// Agent Pool (LocalAGI integration)
|
||||||
AgentPool AgentPoolConfig
|
AgentPool AgentPoolConfig
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
FineTuning FineTuningConfig
|
||||||
|
|
||||||
// Authentication & Authorization
|
// Authentication & Authorization
|
||||||
Auth AuthConfig
|
Auth AuthConfig
|
||||||
}
|
}
|
||||||
@@ -142,6 +145,11 @@ type AgentPoolConfig struct {
|
|||||||
AgentHubURL string // default: "https://agenthub.localai.io"
|
AgentHubURL string // default: "https://agenthub.localai.io"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FineTuningConfig holds configuration for fine-tuning support.
|
||||||
|
type FineTuningConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type AppOption func(*ApplicationConfig)
|
type AppOption func(*ApplicationConfig)
|
||||||
|
|
||||||
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||||
@@ -733,6 +741,12 @@ func WithAgentHubURL(url string) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fine-tuning options
|
||||||
|
|
||||||
|
var EnableFineTuning = func(o *ApplicationConfig) {
|
||||||
|
o.FineTuning.Enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
// Auth options
|
// Auth options
|
||||||
|
|
||||||
func WithAuthEnabled(enabled bool) AppOption {
|
func WithAuthEnabled(enabled bool) AppOption {
|
||||||
|
|||||||
205
core/gallery/importers/local.go
Normal file
205
core/gallery/importers/local.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package importers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImportLocalPath scans a local directory for exported model files and produces
|
||||||
|
// a config.ModelConfig with the correct backend, model path, and options.
|
||||||
|
// Paths in the returned config are relative to modelsPath when possible so that
|
||||||
|
// the YAML config remains portable.
|
||||||
|
//
|
||||||
|
// Detection order:
|
||||||
|
// 1. GGUF files (*.gguf) — uses llama-cpp backend
|
||||||
|
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||||
|
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||||
|
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||||
|
// Make paths relative to the models directory (parent of dirPath)
|
||||||
|
// so config YAML stays portable.
|
||||||
|
modelsDir := filepath.Dir(dirPath)
|
||||||
|
relPath := func(absPath string) string {
|
||||||
|
if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
|
||||||
|
return rel
|
||||||
|
}
|
||||||
|
return absPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
|
||||||
|
ggufFile := findGGUF(dirPath)
|
||||||
|
if ggufFile == "" {
|
||||||
|
ggufSubdir := dirPath + "_gguf"
|
||||||
|
ggufFile = findGGUF(ggufSubdir)
|
||||||
|
}
|
||||||
|
if ggufFile != "" {
|
||||||
|
xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
Name: name,
|
||||||
|
Backend: "llama-cpp",
|
||||||
|
KnownUsecaseStrings: []string{"chat"},
|
||||||
|
Options: []string{"use_jinja:true"},
|
||||||
|
}
|
||||||
|
cfg.Model = relPath(ggufFile)
|
||||||
|
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||||
|
cfg.Description = buildDescription(dirPath, "GGUF")
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. LoRA adapter: look for adapter_config.json
|
||||||
|
|
||||||
|
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
|
||||||
|
if fileExists(adapterConfigPath) {
|
||||||
|
xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
|
||||||
|
baseModel := readBaseModel(dirPath)
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
Name: name,
|
||||||
|
Backend: "transformers",
|
||||||
|
KnownUsecaseStrings: []string{"chat"},
|
||||||
|
}
|
||||||
|
cfg.Model = baseModel
|
||||||
|
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||||
|
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||||
|
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
|
||||||
|
if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
|
||||||
|
xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
|
||||||
|
baseModel := readBaseModel(dirPath)
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
Name: name,
|
||||||
|
Backend: "transformers",
|
||||||
|
KnownUsecaseStrings: []string{"chat"},
|
||||||
|
}
|
||||||
|
cfg.Model = baseModel
|
||||||
|
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||||
|
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||||
|
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
|
||||||
|
if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
|
||||||
|
xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
|
||||||
|
cfg := &config.ModelConfig{
|
||||||
|
Name: name,
|
||||||
|
Backend: "transformers",
|
||||||
|
KnownUsecaseStrings: []string{"chat"},
|
||||||
|
}
|
||||||
|
cfg.Model = relPath(dirPath)
|
||||||
|
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||||
|
cfg.Description = buildDescription(dirPath, "merged model")
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findGGUF returns the path to the first .gguf file found in dir, or "".
|
||||||
|
func findGGUF(dir string) string {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") {
|
||||||
|
return filepath.Join(dir, e.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
|
||||||
|
func readBaseModel(dirPath string) string {
|
||||||
|
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
|
||||||
|
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
|
||||||
|
var ac map[string]any
|
||||||
|
if json.Unmarshal(data, &ac) == nil {
|
||||||
|
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
|
||||||
|
return bm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try export_metadata.json → base_model (Unsloth writes this)
|
||||||
|
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
|
||||||
|
var meta map[string]any
|
||||||
|
if json.Unmarshal(data, &meta) == nil {
|
||||||
|
if bm, ok := meta["base_model"].(string); ok && bm != "" {
|
||||||
|
return bm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDescription creates a human-readable description using available metadata.
|
||||||
|
func buildDescription(dirPath, formatLabel string) string {
|
||||||
|
base := ""
|
||||||
|
|
||||||
|
// Try adapter_config.json
|
||||||
|
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
|
||||||
|
var ac map[string]any
|
||||||
|
if json.Unmarshal(data, &ac) == nil {
|
||||||
|
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
|
||||||
|
base = bm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try export_metadata.json
|
||||||
|
if base == "" {
|
||||||
|
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
|
||||||
|
var meta map[string]any
|
||||||
|
if json.Unmarshal(data, &meta) == nil {
|
||||||
|
if bm, ok := meta["base_model"].(string); ok && bm != "" {
|
||||||
|
base = bm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if base != "" {
|
||||||
|
return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(path string) bool {
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
return err == nil && !info.IsDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasFileWithSuffix(dir, suffix string) bool {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasFileWithPrefix(dir, prefix string) bool {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
148
core/gallery/importers/local_test.go
Normal file
148
core/gallery/importers/local_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package importers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("ImportLocalPath", func() {
|
||||||
|
var tmpDir string
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
tmpDir, err = os.MkdirTemp("", "importers-local-test")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GGUF detection", func() {
|
||||||
|
It("detects a GGUF file in the directory", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "my-model")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Backend).To(Equal("llama-cpp"))
|
||||||
|
Expect(cfg.Model).To(ContainSubstring(".gguf"))
|
||||||
|
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||||
|
Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
|
||||||
|
Expect(cfg.Options).To(ContainElement("use_jinja:true"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("detects GGUF in _gguf subdirectory", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "my-model")
|
||||||
|
ggufDir := modelDir + "_gguf"
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "my-model")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Backend).To(Equal("llama-cpp"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("LoRA adapter detection", func() {
|
||||||
|
It("detects LoRA adapter via adapter_config.json", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "lora-model")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
|
||||||
|
adapterConfig := map[string]any{
|
||||||
|
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
||||||
|
"peft_type": "LORA",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(adapterConfig)
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Backend).To(Equal("transformers"))
|
||||||
|
Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
|
||||||
|
Expect(cfg.LLMConfig.LoraAdapter).To(Equal("lora-model"))
|
||||||
|
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads base model from export_metadata.json as fallback", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "lora-unsloth")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
|
||||||
|
adapterConfig := map[string]any{"peft_type": "LORA"}
|
||||||
|
data, _ := json.Marshal(adapterConfig)
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||||
|
|
||||||
|
metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
|
||||||
|
data, _ = json.Marshal(metadata)
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("Merged model detection", func() {
|
||||||
|
It("detects merged model with safetensors + config.json", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "merged")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "merged")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Backend).To(Equal("transformers"))
|
||||||
|
Expect(cfg.Model).To(Equal("merged"))
|
||||||
|
Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("detects merged model with pytorch_model files", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "merged-pt")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Backend).To(Equal("transformers"))
|
||||||
|
Expect(cfg.Model).To(Equal("merged-pt"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("fallback", func() {
|
||||||
|
It("returns error for empty directory", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "empty")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
|
||||||
|
_, err := importers.ImportLocalPath(modelDir, "empty")
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("could not detect model format"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("description", func() {
|
||||||
|
It("includes base model name in description", func() {
|
||||||
|
modelDir := filepath.Join(tmpDir, "desc-test")
|
||||||
|
Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
|
||||||
|
|
||||||
|
adapterConfig := map[string]any{
|
||||||
|
"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(adapterConfig)
|
||||||
|
Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
|
||||||
|
|
||||||
|
cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
|
||||||
|
Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -302,6 +302,17 @@ func API(application *application.Application) (*echo.Echo, error) {
|
|||||||
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP)
|
||||||
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
|
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw)
|
||||||
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw)
|
||||||
|
// Fine-tuning routes
|
||||||
|
if application.ApplicationConfig().FineTuning.Enabled {
|
||||||
|
fineTuningMw := auth.RequireFeature(application.AuthDB(), auth.FeatureFineTuning)
|
||||||
|
ftService := services.NewFineTuneService(
|
||||||
|
application.ApplicationConfig(),
|
||||||
|
application.ModelLoader(),
|
||||||
|
application.ModelConfigLoader(),
|
||||||
|
)
|
||||||
|
routes.RegisterFineTuningRoutes(e, ftService, application.ApplicationConfig(), fineTuningMw)
|
||||||
|
}
|
||||||
|
|
||||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||||
|
|||||||
@@ -85,6 +85,18 @@ var RouteFeatureRegistry = []RouteFeature{
|
|||||||
{"POST", "/stores/delete", FeatureStores},
|
{"POST", "/stores/delete", FeatureStores},
|
||||||
{"POST", "/stores/get", FeatureStores},
|
{"POST", "/stores/get", FeatureStores},
|
||||||
{"POST", "/stores/find", FeatureStores},
|
{"POST", "/stores/find", FeatureStores},
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
{"POST", "/api/fine-tuning/jobs", FeatureFineTuning},
|
||||||
|
{"GET", "/api/fine-tuning/jobs", FeatureFineTuning},
|
||||||
|
{"GET", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
|
||||||
|
{"POST", "/api/fine-tuning/jobs/:id/stop", FeatureFineTuning},
|
||||||
|
{"DELETE", "/api/fine-tuning/jobs/:id", FeatureFineTuning},
|
||||||
|
{"GET", "/api/fine-tuning/jobs/:id/progress", FeatureFineTuning},
|
||||||
|
{"GET", "/api/fine-tuning/jobs/:id/checkpoints", FeatureFineTuning},
|
||||||
|
{"POST", "/api/fine-tuning/jobs/:id/export", FeatureFineTuning},
|
||||||
|
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
|
||||||
|
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
|
||||||
}
|
}
|
||||||
|
|
||||||
// FeatureMeta describes a feature for the admin API/UI.
|
// FeatureMeta describes a feature for the admin API/UI.
|
||||||
@@ -104,6 +116,13 @@ func AgentFeatureMetas() []FeatureMeta {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GeneralFeatureMetas returns metadata for general features.
|
||||||
|
func GeneralFeatureMetas() []FeatureMeta {
|
||||||
|
return []FeatureMeta{
|
||||||
|
{FeatureFineTuning, "Fine-Tuning", false},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// APIFeatureMetas returns metadata for API endpoint features.
|
// APIFeatureMetas returns metadata for API endpoint features.
|
||||||
func APIFeatureMetas() []FeatureMeta {
|
func APIFeatureMetas() []FeatureMeta {
|
||||||
return []FeatureMeta{
|
return []FeatureMeta{
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ const (
|
|||||||
FeatureCollections = "collections"
|
FeatureCollections = "collections"
|
||||||
FeatureMCPJobs = "mcp_jobs"
|
FeatureMCPJobs = "mcp_jobs"
|
||||||
|
|
||||||
|
// General features (default OFF for new users)
|
||||||
|
FeatureFineTuning = "fine_tuning"
|
||||||
|
|
||||||
// API features (default ON for new users)
|
// API features (default ON for new users)
|
||||||
FeatureChat = "chat"
|
FeatureChat = "chat"
|
||||||
FeatureImages = "images"
|
FeatureImages = "images"
|
||||||
@@ -52,6 +55,9 @@ const (
|
|||||||
// AgentFeatures lists agent-related features (default OFF).
|
// AgentFeatures lists agent-related features (default OFF).
|
||||||
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}
|
var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs}
|
||||||
|
|
||||||
|
// GeneralFeatures lists general features (default OFF).
|
||||||
|
var GeneralFeatures = []string{FeatureFineTuning}
|
||||||
|
|
||||||
// APIFeatures lists API endpoint features (default ON).
|
// APIFeatures lists API endpoint features (default ON).
|
||||||
var APIFeatures = []string{
|
var APIFeatures = []string{
|
||||||
FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
|
FeatureChat, FeatureImages, FeatureAudioSpeech, FeatureAudioTranscription,
|
||||||
@@ -60,7 +66,7 @@ var APIFeatures = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AllFeatures lists all known features (used by UI and validation).
|
// AllFeatures lists all known features (used by UI and validation).
|
||||||
var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...)
|
var AllFeatures = append(append(append([]string{}, AgentFeatures...), GeneralFeatures...), APIFeatures...)
|
||||||
|
|
||||||
// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map.
|
// defaultOnFeatures is the set of features that default to ON when absent from a user's permission map.
|
||||||
var defaultOnFeatures = func() map[string]bool {
|
var defaultOnFeatures = func() map[string]bool {
|
||||||
|
|||||||
@@ -80,13 +80,14 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
|
|||||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
defer src.Close()
|
defer src.Close()
|
||||||
if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil {
|
key, err := svc.UploadToCollectionForUser(userID, name, file.Filename, src)
|
||||||
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
|
return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename, "key": key})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
362
core/http/endpoints/localai/finetune.go
Normal file
362
core/http/endpoints/localai/finetune.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"compress/gzip"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
"github.com/mudler/LocalAI/core/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartFineTuneJobEndpoint starts a new fine-tuning job.
|
||||||
|
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
|
||||||
|
var req schema.FineTuneJobRequest
|
||||||
|
if err := c.Bind(&req); err != nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||||
|
"error": "Invalid request: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model == "" {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||||
|
"error": "model is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if req.DatasetSource == "" {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||||
|
"error": "dataset_source is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ftService.StartJob(c.Request().Context(), userID, req)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusCreated, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
|
||||||
|
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobs := ftService.ListJobs(userID)
|
||||||
|
if jobs == nil {
|
||||||
|
jobs = []*schema.FineTuneJob{}
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, jobs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFineTuneJobEndpoint gets a specific fine-tuning job.
|
||||||
|
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
job, err := ftService.GetJob(userID, jobID)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopFineTuneJobEndpoint stops a running fine-tuning job.
|
||||||
|
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
// Check for save_checkpoint query param
|
||||||
|
saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
|
||||||
|
|
||||||
|
err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]string{
|
||||||
|
"status": "stopped",
|
||||||
|
"message": "Fine-tuning job stopped",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFineTuneJobEndpoint deletes a fine-tuning job and its data.
|
||||||
|
func DeleteFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
err := ftService.DeleteJob(userID, jobID)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
status = http.StatusNotFound
|
||||||
|
} else if strings.Contains(err.Error(), "cannot delete") {
|
||||||
|
status = http.StatusConflict
|
||||||
|
}
|
||||||
|
return c.JSON(status, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]string{
|
||||||
|
"status": "deleted",
|
||||||
|
"message": "Fine-tuning job deleted",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FineTuneProgressEndpoint streams progress updates via SSE.
|
||||||
|
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
// Set SSE headers
|
||||||
|
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Response().Header().Set("Connection", "keep-alive")
|
||||||
|
c.Response().WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
|
||||||
|
data, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(c.Response(), "data: %s\n\n", data)
|
||||||
|
c.Response().Flush()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// If headers already sent, we can't send a JSON error
|
||||||
|
fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
|
||||||
|
c.Response().Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListCheckpointsEndpoint lists checkpoints for a job.
|
||||||
|
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]any{
|
||||||
|
"checkpoints": checkpoints,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportModelEndpoint exports a model from a checkpoint.
|
||||||
|
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
var req schema.ExportRequest
|
||||||
|
if err := c.Bind(&req); err != nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||||
|
"error": "Invalid request: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusAccepted, map[string]string{
|
||||||
|
"status": "exporting",
|
||||||
|
"message": "Export started for model '" + modelName + "'",
|
||||||
|
"model_name": modelName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadExportedModelEndpoint streams the exported model directory as a tar.gz archive.
|
||||||
|
func DownloadExportedModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
userID := getUserID(c)
|
||||||
|
jobID := c.Param("id")
|
||||||
|
|
||||||
|
modelDir, modelName, err := ftService.GetExportedModelPath(userID, jobID)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusNotFound, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Response().Header().Set("Content-Type", "application/gzip")
|
||||||
|
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.tar.gz"`, modelName))
|
||||||
|
c.Response().WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
gw := gzip.NewWriter(c.Response())
|
||||||
|
defer gw.Close()
|
||||||
|
|
||||||
|
tw := tar.NewWriter(gw)
|
||||||
|
defer tw.Close()
|
||||||
|
|
||||||
|
err = filepath.Walk(modelDir, func(path string, info os.FileInfo, walkErr error) error {
|
||||||
|
if walkErr != nil {
|
||||||
|
return walkErr
|
||||||
|
}
|
||||||
|
|
||||||
|
relPath, err := filepath.Rel(modelDir, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := tar.FileInfoHeader(info, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
header.Name = filepath.Join(modelName, relPath)
|
||||||
|
|
||||||
|
if err := tw.WriteHeader(header); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(tw, f)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Headers already sent, can't return JSON error
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
|
||||||
|
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": "failed to list backends: " + err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type backendInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []backendInfo
|
||||||
|
for _, b := range backends {
|
||||||
|
if !b.Installed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hasTag := false
|
||||||
|
for _, t := range b.Tags {
|
||||||
|
if strings.EqualFold(t, "fine-tuning") {
|
||||||
|
hasTag = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasTag {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := b.Name
|
||||||
|
if b.Alias != "" {
|
||||||
|
name = b.Alias
|
||||||
|
}
|
||||||
|
result = append(result, backendInfo{
|
||||||
|
Name: name,
|
||||||
|
Description: b.Description,
|
||||||
|
Tags: b.Tags,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
result = []backendInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadDatasetEndpoint handles dataset file upload.
|
||||||
|
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||||
|
"error": "file is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
src, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": "failed to open file",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
defer src.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(src)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": "failed to read file",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := ftService.UploadDataset(file.Filename, data)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]string{
|
||||||
|
"path": path,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -208,6 +208,32 @@
|
|||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.sidebar-section-toggle {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
width: 100%;
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
font-family: inherit;
|
||||||
|
transition: color var(--duration-fast);
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar-section-toggle:hover {
|
||||||
|
color: var(--color-text-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar-section-chevron {
|
||||||
|
font-size: 0.5rem;
|
||||||
|
transition: transform var(--duration-fast);
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar-section-toggle.open .sidebar-section-chevron {
|
||||||
|
transform: rotate(90deg);
|
||||||
|
}
|
||||||
|
|
||||||
.nav-item {
|
.nav-item {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
@@ -392,6 +418,10 @@
|
|||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.sidebar.collapsed .sidebar-section-chevron {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
.sidebar.collapsed .nav-item {
|
.sidebar.collapsed .nav-item {
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
padding: 8px 0;
|
padding: 8px 0;
|
||||||
@@ -588,14 +618,6 @@
|
|||||||
.spinner-md .spinner-ring { width: 24px; height: 24px; }
|
.spinner-md .spinner-ring { width: 24px; height: 24px; }
|
||||||
.spinner-lg .spinner-ring { width: 40px; height: 40px; }
|
.spinner-lg .spinner-ring { width: 40px; height: 40px; }
|
||||||
|
|
||||||
.spinner-logo {
|
|
||||||
animation: pulse 1.2s ease-in-out infinite;
|
|
||||||
object-fit: contain;
|
|
||||||
}
|
|
||||||
.spinner-sm .spinner-logo { width: 16px; height: 16px; }
|
|
||||||
.spinner-md .spinner-logo { width: 24px; height: 24px; }
|
|
||||||
.spinner-lg .spinner-logo { width: 40px; height: 40px; }
|
|
||||||
|
|
||||||
/* Model selector */
|
/* Model selector */
|
||||||
.model-selector {
|
.model-selector {
|
||||||
background: var(--color-bg-tertiary);
|
background: var(--color-bg-tertiary);
|
||||||
@@ -2622,6 +2644,43 @@
|
|||||||
font-size: 0.625rem;
|
font-size: 0.625rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Studio tabs */
|
||||||
|
.studio-tabs {
|
||||||
|
display: flex;
|
||||||
|
gap: 0;
|
||||||
|
border-bottom: 1px solid var(--color-border-subtle);
|
||||||
|
padding: 0 var(--spacing-xl);
|
||||||
|
background: var(--color-bg-primary);
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
z-index: 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
.studio-tab {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
padding: var(--spacing-sm) var(--spacing-md);
|
||||||
|
font-size: 0.8125rem;
|
||||||
|
font-family: inherit;
|
||||||
|
color: var(--color-text-secondary);
|
||||||
|
cursor: pointer;
|
||||||
|
border-bottom: 2px solid transparent;
|
||||||
|
transition: color var(--duration-fast), border-color var(--duration-fast);
|
||||||
|
}
|
||||||
|
|
||||||
|
.studio-tab:hover {
|
||||||
|
color: var(--color-text-primary);
|
||||||
|
}
|
||||||
|
|
||||||
|
.studio-tab-active {
|
||||||
|
color: var(--color-primary);
|
||||||
|
border-bottom-color: var(--color-primary);
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
/* Two-column layout for media generation pages */
|
/* Two-column layout for media generation pages */
|
||||||
.media-layout {
|
.media-layout {
|
||||||
display: grid;
|
display: grid;
|
||||||
|
|||||||
@@ -1,22 +1,8 @@
|
|||||||
import { useState } from 'react'
|
|
||||||
import { apiUrl } from '../utils/basePath'
|
|
||||||
|
|
||||||
export default function LoadingSpinner({ size = 'md', className = '' }) {
|
export default function LoadingSpinner({ size = 'md', className = '' }) {
|
||||||
const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md'
|
const sizeClass = size === 'sm' ? 'spinner-sm' : size === 'lg' ? 'spinner-lg' : 'spinner-md'
|
||||||
const [imgFailed, setImgFailed] = useState(false)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={`spinner ${sizeClass} ${className}`}>
|
<div className={`spinner ${sizeClass} ${className}`}>
|
||||||
{imgFailed ? (
|
<div className="spinner-ring" />
|
||||||
<div className="spinner-ring" />
|
|
||||||
) : (
|
|
||||||
<img
|
|
||||||
src={apiUrl('/static/logo.png')}
|
|
||||||
alt=""
|
|
||||||
className="spinner-logo"
|
|
||||||
onError={() => setImgFailed(true)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,37 +1,57 @@
|
|||||||
import { useState, useEffect } from 'react'
|
import { useState, useEffect } from 'react'
|
||||||
import { NavLink, useNavigate } from 'react-router-dom'
|
import { NavLink, useNavigate, useLocation } from 'react-router-dom'
|
||||||
import ThemeToggle from './ThemeToggle'
|
import ThemeToggle from './ThemeToggle'
|
||||||
import { useAuth } from '../context/AuthContext'
|
import { useAuth } from '../context/AuthContext'
|
||||||
import { apiUrl } from '../utils/basePath'
|
import { apiUrl } from '../utils/basePath'
|
||||||
|
|
||||||
const COLLAPSED_KEY = 'localai_sidebar_collapsed'
|
const COLLAPSED_KEY = 'localai_sidebar_collapsed'
|
||||||
|
const SECTIONS_KEY = 'localai_sidebar_sections'
|
||||||
|
|
||||||
const mainItems = [
|
const topItems = [
|
||||||
{ path: '/app', icon: 'fas fa-home', label: 'Home' },
|
{ path: '/app', icon: 'fas fa-home', label: 'Home' },
|
||||||
{ path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true },
|
{ path: '/app/models', icon: 'fas fa-download', label: 'Install Models', adminOnly: true },
|
||||||
{ path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' },
|
{ path: '/app/chat', icon: 'fas fa-comments', label: 'Chat' },
|
||||||
{ path: '/app/image', icon: 'fas fa-image', label: 'Images' },
|
{ path: '/app/studio', icon: 'fas fa-palette', label: 'Studio' },
|
||||||
{ path: '/app/video', icon: 'fas fa-video', label: 'Video' },
|
|
||||||
{ path: '/app/tts', icon: 'fas fa-music', label: 'TTS' },
|
|
||||||
{ path: '/app/sound', icon: 'fas fa-volume-high', label: 'Sound' },
|
|
||||||
{ path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' },
|
{ path: '/app/talk', icon: 'fas fa-phone', label: 'Talk' },
|
||||||
{ path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
|
|
||||||
]
|
]
|
||||||
|
|
||||||
const agentItems = [
|
const sections = [
|
||||||
{ path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
|
{
|
||||||
{ path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
|
id: 'tools',
|
||||||
{ path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
|
title: 'Tools',
|
||||||
{ path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
|
items: [
|
||||||
]
|
{ path: '/app/fine-tune', icon: 'fas fa-graduation-cap', label: 'Fine-Tune', feature: 'fine_tuning' },
|
||||||
|
],
|
||||||
const systemItems = [
|
},
|
||||||
{ path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
|
{
|
||||||
{ path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
|
id: 'agents',
|
||||||
{ path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
|
title: 'Agents',
|
||||||
{ path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
|
featureMap: {
|
||||||
{ path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
|
'/app/agents': 'agents',
|
||||||
{ path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
|
'/app/skills': 'skills',
|
||||||
|
'/app/collections': 'collections',
|
||||||
|
'/app/agent-jobs': 'mcp_jobs',
|
||||||
|
},
|
||||||
|
items: [
|
||||||
|
{ path: '/app/agents', icon: 'fas fa-robot', label: 'Agents' },
|
||||||
|
{ path: '/app/skills', icon: 'fas fa-wand-magic-sparkles', label: 'Skills' },
|
||||||
|
{ path: '/app/collections', icon: 'fas fa-database', label: 'Memory' },
|
||||||
|
{ path: '/app/agent-jobs', icon: 'fas fa-tasks', label: 'MCP CI Jobs', feature: 'mcp' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'system',
|
||||||
|
title: 'System',
|
||||||
|
items: [
|
||||||
|
{ path: '/app/usage', icon: 'fas fa-chart-bar', label: 'Usage', authOnly: true },
|
||||||
|
{ path: '/app/users', icon: 'fas fa-users', label: 'Users', adminOnly: true, authOnly: true },
|
||||||
|
{ path: '/app/backends', icon: 'fas fa-server', label: 'Backends', adminOnly: true },
|
||||||
|
{ path: '/app/traces', icon: 'fas fa-chart-line', label: 'Traces', adminOnly: true },
|
||||||
|
{ path: '/app/p2p', icon: 'fas fa-circle-nodes', label: 'Swarm', adminOnly: true },
|
||||||
|
{ path: '/app/manage', icon: 'fas fa-desktop', label: 'System', adminOnly: true },
|
||||||
|
{ path: '/app/settings', icon: 'fas fa-cog', label: 'Settings', adminOnly: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
function NavItem({ item, onClose, collapsed }) {
|
function NavItem({ item, onClose, collapsed }) {
|
||||||
@@ -51,18 +71,47 @@ function NavItem({ item, onClose, collapsed }) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function loadSectionState() {
|
||||||
|
try {
|
||||||
|
const stored = localStorage.getItem(SECTIONS_KEY)
|
||||||
|
return stored ? JSON.parse(stored) : {}
|
||||||
|
} catch (_) {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function saveSectionState(state) {
|
||||||
|
try { localStorage.setItem(SECTIONS_KEY, JSON.stringify(state)) } catch (_) { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
export default function Sidebar({ isOpen, onClose }) {
|
export default function Sidebar({ isOpen, onClose }) {
|
||||||
const [features, setFeatures] = useState({})
|
const [features, setFeatures] = useState({})
|
||||||
const [collapsed, setCollapsed] = useState(() => {
|
const [collapsed, setCollapsed] = useState(() => {
|
||||||
try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false }
|
try { return localStorage.getItem(COLLAPSED_KEY) === 'true' } catch (_) { return false }
|
||||||
})
|
})
|
||||||
|
const [openSections, setOpenSections] = useState(loadSectionState)
|
||||||
const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth()
|
const { isAdmin, authEnabled, user, logout, hasFeature } = useAuth()
|
||||||
const navigate = useNavigate()
|
const navigate = useNavigate()
|
||||||
|
const location = useLocation()
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {})
|
fetch(apiUrl('/api/features')).then(r => r.json()).then(setFeatures).catch(() => {})
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
// Auto-expand section containing the active route
|
||||||
|
useEffect(() => {
|
||||||
|
for (const section of sections) {
|
||||||
|
const match = section.items.some(item => location.pathname.startsWith(item.path))
|
||||||
|
if (match && !openSections[section.id]) {
|
||||||
|
setOpenSections(prev => {
|
||||||
|
const next = { ...prev, [section.id]: true }
|
||||||
|
saveSectionState(next)
|
||||||
|
return next
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [location.pathname])
|
||||||
|
|
||||||
const toggleCollapse = () => {
|
const toggleCollapse = () => {
|
||||||
setCollapsed(prev => {
|
setCollapsed(prev => {
|
||||||
const next = !prev
|
const next = !prev
|
||||||
@@ -72,17 +121,34 @@ export default function Sidebar({ isOpen, onClose }) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const visibleMainItems = mainItems.filter(item => {
|
const toggleSection = (id) => {
|
||||||
if (item.adminOnly && !isAdmin) return false
|
setOpenSections(prev => {
|
||||||
if (item.authOnly && !authEnabled) return false
|
const next = { ...prev, [id]: !prev[id] }
|
||||||
return true
|
saveSectionState(next)
|
||||||
})
|
return next
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const visibleSystemItems = systemItems.filter(item => {
|
const filterItem = (item) => {
|
||||||
if (item.adminOnly && !isAdmin) return false
|
if (item.adminOnly && !isAdmin) return false
|
||||||
if (item.authOnly && !authEnabled) return false
|
if (item.authOnly && !authEnabled) return false
|
||||||
|
if (item.feature && features[item.feature] === false) return false
|
||||||
|
if (item.feature && !hasFeature(item.feature)) return false
|
||||||
return true
|
return true
|
||||||
})
|
}
|
||||||
|
|
||||||
|
const visibleTopItems = topItems.filter(filterItem)
|
||||||
|
|
||||||
|
const getVisibleSectionItems = (section) => {
|
||||||
|
return section.items.filter(item => {
|
||||||
|
if (!filterItem(item)) return false
|
||||||
|
if (section.featureMap) {
|
||||||
|
const featureName = section.featureMap[item.path]
|
||||||
|
return featureName ? hasFeature(featureName) : isAdmin
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -104,57 +170,57 @@ export default function Sidebar({ isOpen, onClose }) {
|
|||||||
|
|
||||||
{/* Navigation */}
|
{/* Navigation */}
|
||||||
<nav className="sidebar-nav">
|
<nav className="sidebar-nav">
|
||||||
{/* Main section */}
|
{/* Top-level items */}
|
||||||
<div className="sidebar-section">
|
<div className="sidebar-section">
|
||||||
{visibleMainItems.map(item => (
|
{visibleTopItems.map(item => (
|
||||||
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
|
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Agents section (per-feature permissions) */}
|
{/* Collapsible sections */}
|
||||||
{features.agents !== false && (() => {
|
{sections.map(section => {
|
||||||
const featureMap = {
|
// For agents section, check global feature flag
|
||||||
'/app/agents': 'agents',
|
if (section.id === 'agents' && features.agents === false) return null
|
||||||
'/app/skills': 'skills',
|
|
||||||
'/app/collections': 'collections',
|
const visibleItems = getVisibleSectionItems(section)
|
||||||
'/app/agent-jobs': 'mcp_jobs',
|
if (visibleItems.length === 0) return null
|
||||||
}
|
|
||||||
const visibleAgentItems = agentItems.filter(item => {
|
const isSectionOpen = openSections[section.id]
|
||||||
if (item.feature && features[item.feature] === false) return false
|
const showItems = isSectionOpen || collapsed
|
||||||
const featureName = featureMap[item.path]
|
|
||||||
return featureName ? hasFeature(featureName) : isAdmin
|
|
||||||
})
|
|
||||||
if (visibleAgentItems.length === 0) return null
|
|
||||||
return (
|
return (
|
||||||
<div className="sidebar-section">
|
<div key={section.id} className="sidebar-section">
|
||||||
<div className="sidebar-section-title">Agents</div>
|
<button
|
||||||
{visibleAgentItems.map(item => (
|
className={`sidebar-section-title sidebar-section-toggle ${isSectionOpen ? 'open' : ''}`}
|
||||||
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
|
onClick={() => toggleSection(section.id)}
|
||||||
))}
|
title={collapsed ? section.title : undefined}
|
||||||
|
>
|
||||||
|
<span>{section.title}</span>
|
||||||
|
<i className="fas fa-chevron-right sidebar-section-chevron" />
|
||||||
|
</button>
|
||||||
|
{showItems && (
|
||||||
|
<div className="sidebar-section-items">
|
||||||
|
{section.id === 'system' && (
|
||||||
|
<a
|
||||||
|
href={apiUrl('/swagger/index.html')}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="nav-item"
|
||||||
|
title={collapsed ? 'API' : undefined}
|
||||||
|
>
|
||||||
|
<i className="fas fa-code nav-icon" />
|
||||||
|
<span className="nav-label">API</span>
|
||||||
|
<i className="fas fa-external-link-alt nav-external" />
|
||||||
|
</a>
|
||||||
|
)}
|
||||||
|
{visibleItems.map(item => (
|
||||||
|
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
})()}
|
})}
|
||||||
|
|
||||||
{/* System section */}
|
|
||||||
<div className="sidebar-section">
|
|
||||||
{visibleSystemItems.length > 0 && (
|
|
||||||
<div className="sidebar-section-title">System</div>
|
|
||||||
)}
|
|
||||||
<a
|
|
||||||
href={apiUrl('/swagger/index.html')}
|
|
||||||
target="_blank"
|
|
||||||
rel="noopener noreferrer"
|
|
||||||
className="nav-item"
|
|
||||||
title={collapsed ? 'API' : undefined}
|
|
||||||
>
|
|
||||||
<i className="fas fa-code nav-icon" />
|
|
||||||
<span className="nav-label">API</span>
|
|
||||||
<i className="fas fa-external-link-alt nav-external" />
|
|
||||||
</a>
|
|
||||||
{visibleSystemItems.map(item => (
|
|
||||||
<NavItem key={item.path} item={item} onClose={onClose} collapsed={collapsed} />
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</nav>
|
</nav>
|
||||||
|
|
||||||
{/* Footer */}
|
{/* Footer */}
|
||||||
|
|||||||
1525
core/http/react-ui/src/pages/FineTune.jsx
Normal file
1525
core/http/react-ui/src/pages/FineTune.jsx
Normal file
File diff suppressed because it is too large
Load Diff
48
core/http/react-ui/src/pages/Studio.jsx
Normal file
48
core/http/react-ui/src/pages/Studio.jsx
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import { useSearchParams } from 'react-router-dom'
|
||||||
|
import ImageGen from './ImageGen'
|
||||||
|
import VideoGen from './VideoGen'
|
||||||
|
import TTS from './TTS'
|
||||||
|
import Sound from './Sound'
|
||||||
|
|
||||||
|
const TABS = [
|
||||||
|
{ key: 'images', label: 'Images', icon: 'fas fa-image' },
|
||||||
|
{ key: 'video', label: 'Video', icon: 'fas fa-video' },
|
||||||
|
{ key: 'tts', label: 'TTS', icon: 'fas fa-headphones' },
|
||||||
|
{ key: 'sound', label: 'Sound', icon: 'fas fa-music' },
|
||||||
|
]
|
||||||
|
|
||||||
|
const TAB_COMPONENTS = {
|
||||||
|
images: ImageGen,
|
||||||
|
video: VideoGen,
|
||||||
|
tts: TTS,
|
||||||
|
sound: Sound,
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function Studio() {
|
||||||
|
const [searchParams, setSearchParams] = useSearchParams()
|
||||||
|
const activeTab = searchParams.get('tab') || 'images'
|
||||||
|
|
||||||
|
const setTab = (key) => {
|
||||||
|
setSearchParams({ tab: key }, { replace: true })
|
||||||
|
}
|
||||||
|
|
||||||
|
const ActiveComponent = TAB_COMPONENTS[activeTab] || ImageGen
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div className="studio-tabs">
|
||||||
|
{TABS.map(tab => (
|
||||||
|
<button
|
||||||
|
key={tab.key}
|
||||||
|
className={`studio-tab${activeTab === tab.key ? ' studio-tab-active' : ''}`}
|
||||||
|
onClick={() => setTab(tab.key)}
|
||||||
|
>
|
||||||
|
<i className={tab.icon} />
|
||||||
|
<span>{tab.label}</span>
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<ActiveComponent />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -45,9 +45,11 @@ function PermissionSummary({ user, onClick }) {
|
|||||||
const perms = user.permissions || {}
|
const perms = user.permissions || {}
|
||||||
const apiFeatures = ['chat', 'images', 'audio_speech', 'audio_transcription', 'vad', 'detection', 'video', 'embeddings', 'sound']
|
const apiFeatures = ['chat', 'images', 'audio_speech', 'audio_transcription', 'vad', 'detection', 'video', 'embeddings', 'sound']
|
||||||
const agentFeatures = ['agents', 'skills', 'collections', 'mcp_jobs']
|
const agentFeatures = ['agents', 'skills', 'collections', 'mcp_jobs']
|
||||||
|
const generalFeatures = ['fine_tuning']
|
||||||
|
|
||||||
const apiOn = apiFeatures.filter(f => perms[f] !== false && (perms[f] === true || perms[f] === undefined)).length
|
const apiOn = apiFeatures.filter(f => perms[f] !== false && (perms[f] === true || perms[f] === undefined)).length
|
||||||
const agentOn = agentFeatures.filter(f => perms[f]).length
|
const agentOn = agentFeatures.filter(f => perms[f]).length
|
||||||
|
const generalOn = generalFeatures.filter(f => perms[f]).length
|
||||||
|
|
||||||
const modelRestricted = user.allowed_models?.enabled
|
const modelRestricted = user.allowed_models?.enabled
|
||||||
|
|
||||||
@@ -58,7 +60,7 @@ function PermissionSummary({ user, onClick }) {
|
|||||||
title="Edit permissions"
|
title="Edit permissions"
|
||||||
>
|
>
|
||||||
<i className="fas fa-shield-halved" />
|
<i className="fas fa-shield-halved" />
|
||||||
{apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent
|
{apiOn}/{apiFeatures.length} API, {agentOn}/{agentFeatures.length} Agent, {generalOn}/{generalFeatures.length} Features
|
||||||
{modelRestricted && ' | Models restricted'}
|
{modelRestricted && ' | Models restricted'}
|
||||||
</button>
|
</button>
|
||||||
)
|
)
|
||||||
@@ -71,6 +73,7 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,
|
|||||||
|
|
||||||
const apiFeatures = featureMeta?.api_features || []
|
const apiFeatures = featureMeta?.api_features || []
|
||||||
const agentFeatures = featureMeta?.agent_features || []
|
const agentFeatures = featureMeta?.agent_features || []
|
||||||
|
const generalFeatures = featureMeta?.general_features || []
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleKeyDown = (e) => {
|
const handleKeyDown = (e) => {
|
||||||
@@ -189,6 +192,33 @@ function PermissionsModal({ user, featureMeta, availableModels, onClose, onSave,
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* General Features */}
|
||||||
|
{generalFeatures.length > 0 && (
|
||||||
|
<div className="perm-section">
|
||||||
|
<div className="perm-section-header">
|
||||||
|
<strong className="perm-section-title">
|
||||||
|
<i className="fas fa-sliders" />
|
||||||
|
Features
|
||||||
|
</strong>
|
||||||
|
<div className="action-group">
|
||||||
|
<button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, true)}>All</button>
|
||||||
|
<button className="btn btn-sm btn-secondary perm-btn-all-none" onClick={() => setAllFeatures(generalFeatures, false)}>None</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="perm-grid">
|
||||||
|
{generalFeatures.map(f => (
|
||||||
|
<button
|
||||||
|
key={f.key}
|
||||||
|
className={`btn btn-sm ${permissions[f.key] ? 'btn-primary' : 'btn-secondary'} perm-btn-feature`}
|
||||||
|
onClick={() => toggleFeature(f.key)}
|
||||||
|
>
|
||||||
|
{f.label}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Model Access */}
|
{/* Model Access */}
|
||||||
<div className="perm-section">
|
<div className="perm-section">
|
||||||
<div className="perm-section-header">
|
<div className="perm-section-header">
|
||||||
@@ -510,6 +540,9 @@ export default function Users() {
|
|||||||
{ key: 'collections', label: 'Collections', default: false },
|
{ key: 'collections', label: 'Collections', default: false },
|
||||||
{ key: 'mcp_jobs', label: 'MCP CI Jobs', default: false },
|
{ key: 'mcp_jobs', label: 'MCP CI Jobs', default: false },
|
||||||
],
|
],
|
||||||
|
general_features: [
|
||||||
|
{ key: 'fine_tuning', label: 'Fine-Tuning', default: false },
|
||||||
|
],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ import ImportModel from './pages/ImportModel'
|
|||||||
import BackendLogs from './pages/BackendLogs'
|
import BackendLogs from './pages/BackendLogs'
|
||||||
import Explorer from './pages/Explorer'
|
import Explorer from './pages/Explorer'
|
||||||
import Login from './pages/Login'
|
import Login from './pages/Login'
|
||||||
|
import FineTune from './pages/FineTune'
|
||||||
|
import Studio from './pages/Studio'
|
||||||
import NotFound from './pages/NotFound'
|
import NotFound from './pages/NotFound'
|
||||||
import Usage from './pages/Usage'
|
import Usage from './pages/Usage'
|
||||||
import Users from './pages/Users'
|
import Users from './pages/Users'
|
||||||
@@ -43,6 +45,7 @@ function BrowseRedirect() {
|
|||||||
return <Navigate to={`/app/${splat || ''}`} replace />
|
return <Navigate to={`/app/${splat || ''}`} replace />
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function Admin({ children }) {
|
function Admin({ children }) {
|
||||||
return <RequireAdmin>{children}</RequireAdmin>
|
return <RequireAdmin>{children}</RequireAdmin>
|
||||||
}
|
}
|
||||||
@@ -64,6 +67,7 @@ const appChildren = [
|
|||||||
{ path: 'tts/:model', element: <TTS /> },
|
{ path: 'tts/:model', element: <TTS /> },
|
||||||
{ path: 'sound', element: <Sound /> },
|
{ path: 'sound', element: <Sound /> },
|
||||||
{ path: 'sound/:model', element: <Sound /> },
|
{ path: 'sound/:model', element: <Sound /> },
|
||||||
|
{ path: 'studio', element: <Studio /> },
|
||||||
{ path: 'talk', element: <Talk /> },
|
{ path: 'talk', element: <Talk /> },
|
||||||
{ path: 'usage', element: <Usage /> },
|
{ path: 'usage', element: <Usage /> },
|
||||||
{ path: 'account', element: <Account /> },
|
{ path: 'account', element: <Account /> },
|
||||||
@@ -89,6 +93,7 @@ const appChildren = [
|
|||||||
{ path: 'agent-jobs/tasks/:id', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
|
{ path: 'agent-jobs/tasks/:id', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
|
||||||
{ path: 'agent-jobs/tasks/:id/edit', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
|
{ path: 'agent-jobs/tasks/:id/edit', element: <Feature feature="mcp_jobs"><AgentTaskDetails /></Feature> },
|
||||||
{ path: 'agent-jobs/jobs/:id', element: <Feature feature="mcp_jobs"><AgentJobDetails /></Feature> },
|
{ path: 'agent-jobs/jobs/:id', element: <Feature feature="mcp_jobs"><AgentJobDetails /></Feature> },
|
||||||
|
{ path: 'fine-tune', element: <Feature feature="fine_tuning"><FineTune /></Feature> },
|
||||||
{ path: 'model-editor/:name', element: <Admin><ModelEditor /></Admin> },
|
{ path: 'model-editor/:name', element: <Admin><ModelEditor /></Admin> },
|
||||||
{ path: 'import-model', element: <Admin><ImportModel /></Admin> },
|
{ path: 'import-model', element: <Admin><ImportModel /></Admin> },
|
||||||
{ path: '*', element: <NotFound /> },
|
{ path: '*', element: <NotFound /> },
|
||||||
|
|||||||
19
core/http/react-ui/src/utils/api.js
vendored
19
core/http/react-ui/src/utils/api.js
vendored
@@ -380,6 +380,25 @@ export const apiKeysApi = {
|
|||||||
revoke: (id) => fetchJSON(`/api/auth/api-keys/${encodeURIComponent(id)}`, { method: 'DELETE' }),
|
revoke: (id) => fetchJSON(`/api/auth/api-keys/${encodeURIComponent(id)}`, { method: 'DELETE' }),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fine-tuning API
|
||||||
|
export const fineTuneApi = {
|
||||||
|
listBackends: () => fetchJSON('/api/fine-tuning/backends'),
|
||||||
|
startJob: (data) => postJSON('/api/fine-tuning/jobs', data),
|
||||||
|
listJobs: () => fetchJSON('/api/fine-tuning/jobs'),
|
||||||
|
getJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`),
|
||||||
|
stopJob: (id, saveCheckpoint) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/stop?save_checkpoint=${saveCheckpoint ? 'true' : 'false'}`, { method: 'POST' }),
|
||||||
|
deleteJob: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}`, { method: 'DELETE' }),
|
||||||
|
listCheckpoints: (id) => fetchJSON(`/api/fine-tuning/jobs/${enc(id)}/checkpoints`),
|
||||||
|
exportModel: (id, data) => postJSON(`/api/fine-tuning/jobs/${enc(id)}/export`, data),
|
||||||
|
uploadDataset: (file) => {
|
||||||
|
const formData = new FormData()
|
||||||
|
formData.append('file', file)
|
||||||
|
return fetch(apiUrl('/api/fine-tuning/datasets'), { method: 'POST', body: formData }).then(handleResponse)
|
||||||
|
},
|
||||||
|
progressUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/progress`),
|
||||||
|
downloadUrl: (id) => apiUrl(`/api/fine-tuning/jobs/${enc(id)}/download`),
|
||||||
|
}
|
||||||
|
|
||||||
// File to base64 helper
|
// File to base64 helper
|
||||||
export function fileToBase64(file) {
|
export function fileToBase64(file) {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
|
|||||||
@@ -777,9 +777,10 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
"agent_features": auth.AgentFeatureMetas(),
|
"agent_features": auth.AgentFeatureMetas(),
|
||||||
"api_features": auth.APIFeatureMetas(),
|
"general_features": auth.GeneralFeatureMetas(),
|
||||||
"models": modelNames,
|
"api_features": auth.APIFeatureMetas(),
|
||||||
|
"models": modelNames,
|
||||||
})
|
})
|
||||||
}, adminMw)
|
}, adminMw)
|
||||||
|
|
||||||
|
|||||||
42
core/http/routes/finetuning.go
Normal file
42
core/http/routes/finetuning.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||||
|
"github.com/mudler/LocalAI/core/services"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterFineTuningRoutes registers fine-tuning API routes.
|
||||||
|
func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) {
|
||||||
|
if ftService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Service readiness middleware
|
||||||
|
readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if ftService == nil {
|
||||||
|
return c.JSON(http.StatusServiceUnavailable, map[string]string{
|
||||||
|
"error": "fine-tuning service is not available",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw)
|
||||||
|
ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig))
|
||||||
|
ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService))
|
||||||
|
ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService))
|
||||||
|
ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService))
|
||||||
|
ft.POST("/jobs/:id/stop", localai.StopFineTuneJobEndpoint(ftService))
|
||||||
|
ft.DELETE("/jobs/:id", localai.DeleteFineTuneJobEndpoint(ftService))
|
||||||
|
ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService))
|
||||||
|
ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService))
|
||||||
|
ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService))
|
||||||
|
ft.GET("/jobs/:id/download", localai.DownloadExportedModelEndpoint(ftService))
|
||||||
|
ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService))
|
||||||
|
}
|
||||||
@@ -134,8 +134,9 @@ func RegisterLocalAIRoutes(router *echo.Echo,
|
|||||||
|
|
||||||
router.GET("/api/features", func(c echo.Context) error {
|
router.GET("/api/features", func(c echo.Context) error {
|
||||||
return c.JSON(200, map[string]bool{
|
return c.JSON(200, map[string]bool{
|
||||||
"agents": appConfig.AgentPool.Enabled,
|
"agents": appConfig.AgentPool.Enabled,
|
||||||
"mcp": !appConfig.DisableMCP,
|
"mcp": !appConfig.DisableMCP,
|
||||||
|
"fine_tuning": appConfig.FineTuning.Enabled,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
111
core/schema/finetune.go
Normal file
111
core/schema/finetune.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
// RewardFunctionSpec defines a reward function for GRPO training.
|
||||||
|
type RewardFunctionSpec struct {
|
||||||
|
Type string `json:"type"` // "builtin" or "inline"
|
||||||
|
Name string `json:"name"`
|
||||||
|
Code string `json:"code,omitempty"` // inline only
|
||||||
|
Params map[string]string `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FineTuneJobRequest is the REST API request to start a fine-tuning job.
|
||||||
|
type FineTuneJobRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Backend string `json:"backend"` // "trl"
|
||||||
|
TrainingType string `json:"training_type,omitempty"` // lora, loha, lokr, full
|
||||||
|
TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo
|
||||||
|
|
||||||
|
// Adapter config
|
||||||
|
AdapterRank int32 `json:"adapter_rank,omitempty"`
|
||||||
|
AdapterAlpha int32 `json:"adapter_alpha,omitempty"`
|
||||||
|
AdapterDropout float32 `json:"adapter_dropout,omitempty"`
|
||||||
|
TargetModules []string `json:"target_modules,omitempty"`
|
||||||
|
|
||||||
|
// Training hyperparameters
|
||||||
|
LearningRate float32 `json:"learning_rate,omitempty"`
|
||||||
|
NumEpochs int32 `json:"num_epochs,omitempty"`
|
||||||
|
BatchSize int32 `json:"batch_size,omitempty"`
|
||||||
|
GradientAccumulationSteps int32 `json:"gradient_accumulation_steps,omitempty"`
|
||||||
|
WarmupSteps int32 `json:"warmup_steps,omitempty"`
|
||||||
|
MaxSteps int32 `json:"max_steps,omitempty"`
|
||||||
|
SaveSteps int32 `json:"save_steps,omitempty"`
|
||||||
|
WeightDecay float32 `json:"weight_decay,omitempty"`
|
||||||
|
GradientCheckpointing bool `json:"gradient_checkpointing,omitempty"`
|
||||||
|
Optimizer string `json:"optimizer,omitempty"`
|
||||||
|
Seed int32 `json:"seed,omitempty"`
|
||||||
|
MixedPrecision string `json:"mixed_precision,omitempty"`
|
||||||
|
|
||||||
|
// Dataset
|
||||||
|
DatasetSource string `json:"dataset_source"`
|
||||||
|
DatasetSplit string `json:"dataset_split,omitempty"`
|
||||||
|
|
||||||
|
// Resume from a checkpoint
|
||||||
|
ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"`
|
||||||
|
|
||||||
|
// GRPO reward functions
|
||||||
|
RewardFunctions []RewardFunctionSpec `json:"reward_functions,omitempty"`
|
||||||
|
|
||||||
|
// Backend-specific and method-specific options
|
||||||
|
ExtraOptions map[string]string `json:"extra_options,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FineTuneJob represents a fine-tuning job with its current state.
|
||||||
|
type FineTuneJob struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Backend string `json:"backend"`
|
||||||
|
ModelID string `json:"model_id,omitempty"` // backend model loader ID
|
||||||
|
TrainingType string `json:"training_type"`
|
||||||
|
TrainingMethod string `json:"training_method"`
|
||||||
|
Status string `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
OutputDir string `json:"output_dir"`
|
||||||
|
ExtraOptions map[string]string `json:"extra_options,omitempty"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
|
||||||
|
// Export state (tracked separately from training status)
|
||||||
|
ExportStatus string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed"
|
||||||
|
ExportMessage string `json:"export_message,omitempty"`
|
||||||
|
ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export
|
||||||
|
|
||||||
|
// Full config for resume/reuse
|
||||||
|
Config *FineTuneJobRequest `json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FineTuneJobResponse is the REST API response when creating a job.
|
||||||
|
type FineTuneJobResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FineTuneProgressEvent is an SSE event for training progress.
|
||||||
|
type FineTuneProgressEvent struct {
|
||||||
|
JobID string `json:"job_id"`
|
||||||
|
CurrentStep int32 `json:"current_step"`
|
||||||
|
TotalSteps int32 `json:"total_steps"`
|
||||||
|
CurrentEpoch float32 `json:"current_epoch"`
|
||||||
|
TotalEpochs float32 `json:"total_epochs"`
|
||||||
|
Loss float32 `json:"loss"`
|
||||||
|
LearningRate float32 `json:"learning_rate"`
|
||||||
|
GradNorm float32 `json:"grad_norm"`
|
||||||
|
EvalLoss float32 `json:"eval_loss"`
|
||||||
|
EtaSeconds float32 `json:"eta_seconds"`
|
||||||
|
ProgressPercent float32 `json:"progress_percent"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
CheckpointPath string `json:"checkpoint_path,omitempty"`
|
||||||
|
SamplePath string `json:"sample_path,omitempty"`
|
||||||
|
ExtraMetrics map[string]float32 `json:"extra_metrics,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportRequest is the REST API request to export a model.
|
||||||
|
type ExportRequest struct {
|
||||||
|
Name string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
|
||||||
|
CheckpointPath string `json:"checkpoint_path"`
|
||||||
|
ExportFormat string `json:"export_format"` // lora, merged_16bit, merged_4bit, gguf
|
||||||
|
QuantizationMethod string `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16
|
||||||
|
Model string `json:"model,omitempty"` // base model name for merge
|
||||||
|
ExtraOptions map[string]string `json:"extra_options,omitempty"`
|
||||||
|
}
|
||||||
@@ -1042,7 +1042,7 @@ func (s *AgentPoolService) CreateCollection(name string) error {
|
|||||||
return s.collectionsBackend.CreateCollection(name)
|
return s.collectionsBackend.CreateCollection(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AgentPoolService) UploadToCollection(collection, filename string, fileBody io.Reader) error {
|
func (s *AgentPoolService) UploadToCollection(collection, filename string, fileBody io.Reader) (string, error) {
|
||||||
return s.collectionsBackend.Upload(collection, filename, fileBody)
|
return s.collectionsBackend.Upload(collection, filename, fileBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1554,10 +1554,10 @@ func (s *AgentPoolService) CreateCollectionForUser(userID, name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UploadToCollectionForUser uploads to a collection for a specific user.
|
// UploadToCollectionForUser uploads to a collection for a specific user.
|
||||||
func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) error {
|
func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) (string, error) {
|
||||||
backend, err := s.CollectionsBackendForUser(userID)
|
backend, err := s.CollectionsBackendForUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
return backend.Upload(collection, filename, fileBody)
|
return backend.Upload(collection, filename, fileBody)
|
||||||
}
|
}
|
||||||
|
|||||||
700
core/services/finetune.go
Normal file
700
core/services/finetune.go
Normal file
@@ -0,0 +1,700 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||||
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
"github.com/mudler/xlog"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FineTuneService manages fine-tuning jobs and their lifecycle.
|
||||||
|
type FineTuneService struct {
|
||||||
|
appConfig *config.ApplicationConfig
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
configLoader *config.ModelConfigLoader
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
jobs map[string]*schema.FineTuneJob
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFineTuneService creates a new FineTuneService.
|
||||||
|
func NewFineTuneService(
|
||||||
|
appConfig *config.ApplicationConfig,
|
||||||
|
modelLoader *model.ModelLoader,
|
||||||
|
configLoader *config.ModelConfigLoader,
|
||||||
|
) *FineTuneService {
|
||||||
|
s := &FineTuneService{
|
||||||
|
appConfig: appConfig,
|
||||||
|
modelLoader: modelLoader,
|
||||||
|
configLoader: configLoader,
|
||||||
|
jobs: make(map[string]*schema.FineTuneJob),
|
||||||
|
}
|
||||||
|
s.loadAllJobs()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// fineTuneBaseDir returns the base directory for fine-tune job data.
|
||||||
|
func (s *FineTuneService) fineTuneBaseDir() string {
|
||||||
|
return filepath.Join(s.appConfig.DataPath, "fine-tune")
|
||||||
|
}
|
||||||
|
|
||||||
|
// jobDir returns the directory for a specific job.
|
||||||
|
func (s *FineTuneService) jobDir(jobID string) string {
|
||||||
|
return filepath.Join(s.fineTuneBaseDir(), jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveJobState persists a job's state to disk as state.json.
|
||||||
|
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
|
||||||
|
dir := s.jobDir(job.ID)
|
||||||
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
|
xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(job, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
statePath := filepath.Join(dir, "state.json")
|
||||||
|
if err := os.WriteFile(statePath, data, 0640); err != nil {
|
||||||
|
xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
|
||||||
|
func (s *FineTuneService) loadAllJobs() {
|
||||||
|
baseDir := s.fineTuneBaseDir()
|
||||||
|
entries, err := os.ReadDir(baseDir)
|
||||||
|
if err != nil {
|
||||||
|
// Directory doesn't exist yet — that's fine
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if !entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
statePath := filepath.Join(baseDir, entry.Name(), "state.json")
|
||||||
|
data, err := os.ReadFile(statePath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var job schema.FineTuneJob
|
||||||
|
if err := json.Unmarshal(data, &job); err != nil {
|
||||||
|
xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Jobs that were running when we shut down are now stale
|
||||||
|
if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" {
|
||||||
|
job.Status = "stopped"
|
||||||
|
job.Message = "Server restarted while job was running"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exports that were in progress are now stale
|
||||||
|
if job.ExportStatus == "exporting" {
|
||||||
|
job.ExportStatus = "failed"
|
||||||
|
job.ExportMessage = "Server restarted while export was running"
|
||||||
|
}
|
||||||
|
|
||||||
|
s.jobs[job.ID] = &job
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.jobs) > 0 {
|
||||||
|
xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartJob starts a new fine-tuning job.
|
||||||
|
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
jobID := uuid.New().String()
|
||||||
|
|
||||||
|
backendName := req.Backend
|
||||||
|
if backendName == "" {
|
||||||
|
backendName = "trl"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always use DataPath for output — not user-configurable
|
||||||
|
outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)
|
||||||
|
|
||||||
|
// Build gRPC request
|
||||||
|
grpcReq := &pb.FineTuneRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
TrainingType: req.TrainingType,
|
||||||
|
TrainingMethod: req.TrainingMethod,
|
||||||
|
AdapterRank: req.AdapterRank,
|
||||||
|
AdapterAlpha: req.AdapterAlpha,
|
||||||
|
AdapterDropout: req.AdapterDropout,
|
||||||
|
TargetModules: req.TargetModules,
|
||||||
|
LearningRate: req.LearningRate,
|
||||||
|
NumEpochs: req.NumEpochs,
|
||||||
|
BatchSize: req.BatchSize,
|
||||||
|
GradientAccumulationSteps: req.GradientAccumulationSteps,
|
||||||
|
WarmupSteps: req.WarmupSteps,
|
||||||
|
MaxSteps: req.MaxSteps,
|
||||||
|
SaveSteps: req.SaveSteps,
|
||||||
|
WeightDecay: req.WeightDecay,
|
||||||
|
GradientCheckpointing: req.GradientCheckpointing,
|
||||||
|
Optimizer: req.Optimizer,
|
||||||
|
Seed: req.Seed,
|
||||||
|
MixedPrecision: req.MixedPrecision,
|
||||||
|
DatasetSource: req.DatasetSource,
|
||||||
|
DatasetSplit: req.DatasetSplit,
|
||||||
|
OutputDir: outputDir,
|
||||||
|
JobId: jobID,
|
||||||
|
ResumeFromCheckpoint: req.ResumeFromCheckpoint,
|
||||||
|
ExtraOptions: req.ExtraOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize reward functions into extra_options for the backend
|
||||||
|
if len(req.RewardFunctions) > 0 {
|
||||||
|
rfJSON, err := json.Marshal(req.RewardFunctions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to serialize reward functions: %w", err)
|
||||||
|
}
|
||||||
|
if grpcReq.ExtraOptions == nil {
|
||||||
|
grpcReq.ExtraOptions = make(map[string]string)
|
||||||
|
}
|
||||||
|
grpcReq.ExtraOptions["reward_funcs"] = string(rfJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the fine-tuning backend (per-job model ID so multiple jobs can run concurrently)
|
||||||
|
modelID := backendName + "-finetune-" + jobID
|
||||||
|
backendModel, err := s.modelLoader.Load(
|
||||||
|
model.WithBackendString(backendName),
|
||||||
|
model.WithModel(backendName),
|
||||||
|
model.WithModelID(modelID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start fine-tuning via gRPC
|
||||||
|
result, err := backendModel.StartFineTune(ctx, grpcReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
|
||||||
|
}
|
||||||
|
if !result.Success {
|
||||||
|
return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track the job
|
||||||
|
job := &schema.FineTuneJob{
|
||||||
|
ID: jobID,
|
||||||
|
UserID: userID,
|
||||||
|
Model: req.Model,
|
||||||
|
Backend: backendName,
|
||||||
|
ModelID: modelID,
|
||||||
|
TrainingType: req.TrainingType,
|
||||||
|
TrainingMethod: req.TrainingMethod,
|
||||||
|
Status: "queued",
|
||||||
|
OutputDir: outputDir,
|
||||||
|
ExtraOptions: req.ExtraOptions,
|
||||||
|
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
Config: &req,
|
||||||
|
}
|
||||||
|
s.jobs[jobID] = job
|
||||||
|
s.saveJobState(job)
|
||||||
|
|
||||||
|
return &schema.FineTuneJobResponse{
|
||||||
|
ID: jobID,
|
||||||
|
Status: "queued",
|
||||||
|
Message: result.Message,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJob returns a fine-tuning job by ID.
|
||||||
|
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
return job, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListJobs returns all jobs for a user, sorted by creation time (newest first).
|
||||||
|
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var result []*schema.FineTuneJob
|
||||||
|
for _, job := range s.jobs {
|
||||||
|
if userID == "" || job.UserID == userID {
|
||||||
|
result = append(result, job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(result, func(i, j int) bool {
|
||||||
|
return result[i].CreatedAt > result[j].CreatedAt
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopJob stops a running fine-tuning job.
|
||||||
|
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Kill the backend process directly
|
||||||
|
stopModelID := job.ModelID
|
||||||
|
if stopModelID == "" {
|
||||||
|
stopModelID = job.Backend + "-finetune"
|
||||||
|
}
|
||||||
|
s.modelLoader.ShutdownModel(stopModelID)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
job.Status = "stopped"
|
||||||
|
job.Message = "Training stopped by user"
|
||||||
|
s.saveJobState(job)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteJob removes a fine-tuning job and its associated data from disk.
|
||||||
|
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject deletion of actively running jobs
|
||||||
|
activeStatuses := map[string]bool{
|
||||||
|
"queued": true, "loading_model": true, "loading_dataset": true,
|
||||||
|
"training": true, "saving": true,
|
||||||
|
}
|
||||||
|
if activeStatuses[job.Status] {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status)
|
||||||
|
}
|
||||||
|
if job.ExportStatus == "exporting" {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("cannot delete job %s: export in progress", jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
exportModelName := job.ExportModelName
|
||||||
|
delete(s.jobs, jobID)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove job directory (state.json, checkpoints, output)
|
||||||
|
jobDir := s.jobDir(jobID)
|
||||||
|
if err := os.RemoveAll(jobDir); err != nil {
|
||||||
|
xlog.Warn("Failed to remove job directory", "job_id", jobID, "path", jobDir, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If an exported model exists, clean it up too
|
||||||
|
if exportModelName != "" {
|
||||||
|
modelsPath := s.appConfig.SystemState.Model.ModelsPath
|
||||||
|
modelDir := filepath.Join(modelsPath, exportModelName)
|
||||||
|
configPath := filepath.Join(modelsPath, exportModelName+".yaml")
|
||||||
|
|
||||||
|
if err := os.RemoveAll(modelDir); err != nil {
|
||||||
|
xlog.Warn("Failed to remove exported model directory", "path", modelDir, "error", err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
xlog.Warn("Failed to remove exported model config", "path", configPath, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload model configs
|
||||||
|
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||||
|
xlog.Warn("Failed to reload configs after delete", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("Deleted fine-tune job", "job_id", jobID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
||||||
|
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
streamModelID := job.ModelID
|
||||||
|
if streamModelID == "" {
|
||||||
|
streamModelID = job.Backend + "-finetune"
|
||||||
|
}
|
||||||
|
backendModel, err := s.modelLoader.Load(
|
||||||
|
model.WithBackendString(job.Backend),
|
||||||
|
model.WithModel(job.Backend),
|
||||||
|
model.WithModelID(streamModelID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load backend: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{
|
||||||
|
JobId: jobID,
|
||||||
|
}, func(update *pb.FineTuneProgressUpdate) {
|
||||||
|
// Update job status and persist
|
||||||
|
s.mu.Lock()
|
||||||
|
if j, ok := s.jobs[jobID]; ok {
|
||||||
|
// Don't let progress updates overwrite terminal states
|
||||||
|
isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
|
||||||
|
if !isTerminal {
|
||||||
|
j.Status = update.Status
|
||||||
|
}
|
||||||
|
if update.Message != "" {
|
||||||
|
j.Message = update.Message
|
||||||
|
}
|
||||||
|
s.saveJobState(j)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Convert extra metrics
|
||||||
|
extraMetrics := make(map[string]float32)
|
||||||
|
for k, v := range update.ExtraMetrics {
|
||||||
|
extraMetrics[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
event := &schema.FineTuneProgressEvent{
|
||||||
|
JobID: update.JobId,
|
||||||
|
CurrentStep: update.CurrentStep,
|
||||||
|
TotalSteps: update.TotalSteps,
|
||||||
|
CurrentEpoch: update.CurrentEpoch,
|
||||||
|
TotalEpochs: update.TotalEpochs,
|
||||||
|
Loss: update.Loss,
|
||||||
|
LearningRate: update.LearningRate,
|
||||||
|
GradNorm: update.GradNorm,
|
||||||
|
EvalLoss: update.EvalLoss,
|
||||||
|
EtaSeconds: update.EtaSeconds,
|
||||||
|
ProgressPercent: update.ProgressPercent,
|
||||||
|
Status: update.Status,
|
||||||
|
Message: update.Message,
|
||||||
|
CheckpointPath: update.CheckpointPath,
|
||||||
|
SamplePath: update.SamplePath,
|
||||||
|
ExtraMetrics: extraMetrics,
|
||||||
|
}
|
||||||
|
callback(event)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListCheckpoints lists checkpoints for a job.
|
||||||
|
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
ckptModelID := job.ModelID
|
||||||
|
if ckptModelID == "" {
|
||||||
|
ckptModelID = job.Backend + "-finetune"
|
||||||
|
}
|
||||||
|
backendModel, err := s.modelLoader.Load(
|
||||||
|
model.WithBackendString(job.Backend),
|
||||||
|
model.WithModel(job.Backend),
|
||||||
|
model.WithModelID(ckptModelID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load backend: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
|
||||||
|
OutputDir: job.OutputDir,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list checkpoints: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Checkpoints, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeModelName replaces non-alphanumeric characters with hyphens and lowercases.
|
||||||
|
func sanitizeModelName(s string) string {
|
||||||
|
re := regexp.MustCompile(`[^a-zA-Z0-9\-]`)
|
||||||
|
s = re.ReplaceAllString(s, "-")
|
||||||
|
s = regexp.MustCompile(`-+`).ReplaceAllString(s, "-")
|
||||||
|
s = strings.Trim(s, "-")
|
||||||
|
return strings.ToLower(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
|
||||||
|
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if job.ExportStatus == "exporting" {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", fmt.Errorf("export already in progress for job %s", jobID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Compute model name
|
||||||
|
modelName := req.Name
|
||||||
|
if modelName == "" {
|
||||||
|
base := sanitizeModelName(job.Model)
|
||||||
|
if base == "" {
|
||||||
|
base = "model"
|
||||||
|
}
|
||||||
|
shortID := jobID
|
||||||
|
if len(shortID) > 8 {
|
||||||
|
shortID = shortID[:8]
|
||||||
|
}
|
||||||
|
modelName = base + "-ft-" + shortID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute output path in models directory
|
||||||
|
modelsPath := s.appConfig.SystemState.Model.ModelsPath
|
||||||
|
outputPath := filepath.Join(modelsPath, modelName)
|
||||||
|
|
||||||
|
// Check for name collision (synchronous — fast validation)
|
||||||
|
configPath := filepath.Join(modelsPath, modelName+".yaml")
|
||||||
|
if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid model name: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configPath); err == nil {
|
||||||
|
return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create output directory
|
||||||
|
if err := os.MkdirAll(outputPath, 0750); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create output directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set export status to "exporting" and persist
|
||||||
|
s.mu.Lock()
|
||||||
|
job.ExportStatus = "exporting"
|
||||||
|
job.ExportMessage = ""
|
||||||
|
job.ExportModelName = ""
|
||||||
|
s.saveJobState(job)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// Launch the export in a background goroutine
|
||||||
|
go func() {
|
||||||
|
s.setExportMessage(job, "Loading export backend...")
|
||||||
|
|
||||||
|
exportModelID := job.ModelID
|
||||||
|
if exportModelID == "" {
|
||||||
|
exportModelID = job.Backend + "-finetune"
|
||||||
|
}
|
||||||
|
backendModel, err := s.modelLoader.Load(
|
||||||
|
model.WithBackendString(job.Backend),
|
||||||
|
model.WithModel(job.Backend),
|
||||||
|
model.WithModelID(exportModelID),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge job's extra_options (contains hf_token from training) with request's
|
||||||
|
mergedOpts := make(map[string]string)
|
||||||
|
for k, v := range job.ExtraOptions {
|
||||||
|
mergedOpts[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range req.ExtraOptions {
|
||||||
|
mergedOpts[k] = v // request overrides job
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcReq := &pb.ExportModelRequest{
|
||||||
|
CheckpointPath: req.CheckpointPath,
|
||||||
|
OutputPath: outputPath,
|
||||||
|
ExportFormat: req.ExportFormat,
|
||||||
|
QuantizationMethod: req.QuantizationMethod,
|
||||||
|
Model: req.Model,
|
||||||
|
ExtraOptions: mergedOpts,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setExportMessage(job, "Running model export (merging and converting — this may take a while)...")
|
||||||
|
|
||||||
|
result, err := backendModel.ExportModel(context.Background(), grpcReq)
|
||||||
|
if err != nil {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !result.Success {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setExportMessage(job, "Export complete, generating model configuration...")
|
||||||
|
|
||||||
|
// Auto-import: detect format and generate config
|
||||||
|
cfg, err := importers.ImportLocalPath(outputPath, modelName)
|
||||||
|
if err != nil {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Name = modelName
|
||||||
|
|
||||||
|
// If base model not detected from files, use the job's model field
|
||||||
|
if cfg.Model == "" && job.Model != "" {
|
||||||
|
cfg.Model = job.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write YAML config
|
||||||
|
yamlData, err := yaml.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||||
|
s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setExportMessage(job, "Registering model with LocalAI...")
|
||||||
|
|
||||||
|
// Reload configs so the model is immediately available
|
||||||
|
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||||
|
xlog.Warn("Failed to reload configs after export", "error", err)
|
||||||
|
}
|
||||||
|
if err := s.configLoader.Preload(modelsPath); err != nil {
|
||||||
|
xlog.Warn("Failed to preload after export", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
job.ExportStatus = "completed"
|
||||||
|
job.ExportModelName = modelName
|
||||||
|
job.ExportMessage = ""
|
||||||
|
s.saveJobState(job)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return modelName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setExportMessage updates the export message and persists the job state.
|
||||||
|
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
job.ExportMessage = msg
|
||||||
|
s.saveJobState(job)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExportedModelPath returns the path to the exported model directory and its name.
|
||||||
|
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
job, ok := s.jobs[jobID]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if userID != "" && job.UserID != userID {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", "", fmt.Errorf("job not found: %s", jobID)
|
||||||
|
}
|
||||||
|
if job.ExportStatus != "completed" {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", "", fmt.Errorf("export not completed for job %s (status: %s)", jobID, job.ExportStatus)
|
||||||
|
}
|
||||||
|
exportModelName := job.ExportModelName
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if exportModelName == "" {
|
||||||
|
return "", "", fmt.Errorf("no exported model name for job %s", jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsPath := s.appConfig.SystemState.Model.ModelsPath
|
||||||
|
modelDir := filepath.Join(modelsPath, exportModelName)
|
||||||
|
|
||||||
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||||
|
return "", "", fmt.Errorf("exported model directory not found: %s", modelDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelDir, exportModelName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setExportFailed sets the export status to failed with a message.
|
||||||
|
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
|
||||||
|
xlog.Error("Export failed", "job_id", job.ID, "error", message)
|
||||||
|
s.mu.Lock()
|
||||||
|
job.ExportStatus = "failed"
|
||||||
|
job.ExportMessage = message
|
||||||
|
s.saveJobState(job)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadDataset handles dataset file upload and returns the local path.
|
||||||
|
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
|
||||||
|
uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
|
||||||
|
if err := os.MkdirAll(uploadDir, 0750); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create dataset directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename)
|
||||||
|
if err := os.WriteFile(filePath, data, 0640); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to write dataset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalProgressEvent converts a progress event to JSON for SSE.
|
||||||
|
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
|
||||||
|
data, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
226
docs/content/features/fine-tuning.md
Normal file
226
docs/content/features/fine-tuning.md
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
+++
|
||||||
|
disableToc = false
|
||||||
|
title = "Fine-Tuning"
|
||||||
|
weight = 18
|
||||||
|
url = '/features/fine-tuning/'
|
||||||
|
+++
|
||||||
|
|
||||||
|
LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types.
|
||||||
|
|
||||||
|
## Supported Backends
|
||||||
|
|
||||||
|
| Backend | Domain | GPU Required | Training Methods | Adapter Types |
|
||||||
|
|---------|--------|-------------|-----------------|---------------|
|
||||||
|
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |
|
||||||
|
|
||||||
|
## Enabling Fine-Tuning
|
||||||
|
|
||||||
|
Fine-tuning is disabled by default. Enable it with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LOCALAI_ENABLE_FINETUNING=true local-ai
|
||||||
|
```
|
||||||
|
|
||||||
|
When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Start a fine-tuning job
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
"backend": "trl",
|
||||||
|
"training_method": "sft",
|
||||||
|
"training_type": "lora",
|
||||||
|
"dataset_source": "yahma/alpaca-cleaned",
|
||||||
|
"num_epochs": 1,
|
||||||
|
"batch_size": 2,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"adapter_rank": 16,
|
||||||
|
"adapter_alpha": 16,
|
||||||
|
"extra_options": {
|
||||||
|
"max_seq_length": "512"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Monitor progress (SSE stream)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. List checkpoints
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Export model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"export_format": "gguf",
|
||||||
|
"quantization_method": "q4_k_m",
|
||||||
|
"output_path": "/models/my-finetuned-model"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Endpoints
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job |
|
||||||
|
| `GET` | `/api/fine-tuning/jobs` | List all jobs |
|
||||||
|
| `GET` | `/api/fine-tuning/jobs/:id` | Get job details |
|
||||||
|
| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job |
|
||||||
|
| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream |
|
||||||
|
| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints |
|
||||||
|
| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model |
|
||||||
|
| `POST` | `/api/fine-tuning/datasets` | Upload dataset file |
|
||||||
|
|
||||||
|
### Job Request Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `model` | string | HuggingFace model ID or local path (required) |
|
||||||
|
| `backend` | string | Backend name (default: `trl`) |
|
||||||
|
| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` |
|
||||||
|
| `training_type` | string | `lora` or `full` |
|
||||||
|
| `dataset_source` | string | HuggingFace dataset ID or local file path (required) |
|
||||||
|
| `adapter_rank` | int | LoRA rank (default: 16) |
|
||||||
|
| `adapter_alpha` | int | LoRA alpha (default: 16) |
|
||||||
|
| `num_epochs` | int | Number of training epochs (default: 3) |
|
||||||
|
| `batch_size` | int | Per-device batch size (default: 2) |
|
||||||
|
| `learning_rate` | float | Learning rate (default: 2e-4) |
|
||||||
|
| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) |
|
||||||
|
| `warmup_steps` | int | Warmup steps (default: 5) |
|
||||||
|
| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` |
|
||||||
|
| `extra_options` | map | Backend-specific options (see below) |
|
||||||
|
|
||||||
|
### Backend-Specific Options (`extra_options`)
|
||||||
|
|
||||||
|
#### TRL
|
||||||
|
|
||||||
|
| Key | Description | Default |
|
||||||
|
|-----|-------------|---------|
|
||||||
|
| `max_seq_length` | Maximum sequence length | `512` |
|
||||||
|
| `packing` | Enable sequence packing | `false` |
|
||||||
|
| `trust_remote_code` | Trust remote code in model | `false` |
|
||||||
|
| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` |
|
||||||
|
|
||||||
|
#### DPO-specific (training_method=dpo)
|
||||||
|
|
||||||
|
| Key | Description | Default |
|
||||||
|
|-----|-------------|---------|
|
||||||
|
| `beta` | KL penalty coefficient | `0.1` |
|
||||||
|
| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` |
|
||||||
|
| `max_length` | Maximum sequence length | `512` |
|
||||||
|
|
||||||
|
#### GRPO-specific (training_method=grpo)
|
||||||
|
|
||||||
|
| Key | Description | Default |
|
||||||
|
|-----|-------------|---------|
|
||||||
|
| `num_generations` | Number of generations per prompt | `4` |
|
||||||
|
| `max_completion_length` | Max completion token length | `256` |
|
||||||
|
|
||||||
|
### GRPO Reward Functions
|
||||||
|
|
||||||
|
GRPO training requires reward functions to evaluate model completions. Specify them via the `reward_functions` field (a typed array) or via `extra_options["reward_funcs"]` (a JSON string).
|
||||||
|
|
||||||
|
#### Built-in Reward Functions
|
||||||
|
|
||||||
|
| Name | Description | Parameters |
|
||||||
|
|------|-------------|-----------|
|
||||||
|
| `format_reward` | Checks `<think>...</think>` then answer format (1.0/0.0) | — |
|
||||||
|
| `reasoning_accuracy_reward` | Extracts `<answer>` content, compares to dataset's `answer` column | — |
|
||||||
|
| `length_reward` | Score based on proximity to target length [0, 1] | `target_length` (default: 200) |
|
||||||
|
| `xml_tag_reward` | Scores properly opened/closed `<think>` and `<answer>` tags | — |
|
||||||
|
| `no_repetition_reward` | Penalizes n-gram repetition [0, 1] | — |
|
||||||
|
| `code_execution_reward` | Checks Python code block syntax validity (1.0/0.0) | — |
|
||||||
|
|
||||||
|
#### Inline Custom Reward Functions
|
||||||
|
|
||||||
|
You can provide custom reward function code as a Python function body. The function receives `completions` (list of strings) and `**kwargs`, and must return `list[float]`.
|
||||||
|
|
||||||
|
**Security restrictions for inline code:**
|
||||||
|
- Allowed builtins: `len`, `int`, `float`, `str`, `list`, `dict`, `range`, `enumerate`, `zip`, `map`, `filter`, `sorted`, `min`, `max`, `sum`, `abs`, `round`, `any`, `all`, `isinstance`, `print`, `True`, `False`, `None`
|
||||||
|
- Available modules: `re`, `math`, `json`, `string`
|
||||||
|
- Blocked: `open`, `__import__`, `exec`, `eval`, `compile`, `os`, `subprocess`, `getattr`, `setattr`, `delattr`, `globals`, `locals`
|
||||||
|
- Functions are compiled and validated at job start (fail-fast on syntax errors)
|
||||||
|
|
||||||
|
#### Example API Request
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "Qwen/Qwen2.5-1.5B-Instruct",
|
||||||
|
"backend": "trl",
|
||||||
|
"training_method": "grpo",
|
||||||
|
"training_type": "lora",
|
||||||
|
"dataset_source": "my-reasoning-dataset",
|
||||||
|
"num_epochs": 1,
|
||||||
|
"batch_size": 2,
|
||||||
|
"learning_rate": 5e-6,
|
||||||
|
"reward_functions": [
|
||||||
|
{"type": "builtin", "name": "reasoning_accuracy_reward"},
|
||||||
|
{"type": "builtin", "name": "format_reward"},
|
||||||
|
{"type": "builtin", "name": "length_reward", "params": {"target_length": "200"}},
|
||||||
|
{"type": "inline", "name": "think_presence", "code": "return [1.0 if \"<think>\" in c else 0.0 for c in completions]"}
|
||||||
|
],
|
||||||
|
"extra_options": {
|
||||||
|
"num_generations": "4",
|
||||||
|
"max_completion_length": "256"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export Formats
|
||||||
|
|
||||||
|
| Format | Description | Notes |
|
||||||
|
|--------|-------------|-------|
|
||||||
|
| `lora` | LoRA adapter files | Smallest, requires base model |
|
||||||
|
| `merged_16bit` | Full model in 16-bit | Large but standalone |
|
||||||
|
| `merged_4bit` | Full model in 4-bit | Smaller, standalone |
|
||||||
|
| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` |
|
||||||
|
|
||||||
|
### GGUF Quantization Methods
|
||||||
|
|
||||||
|
`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0`
|
||||||
|
|
||||||
|
## Web UI
|
||||||
|
|
||||||
|
When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides:
|
||||||
|
|
||||||
|
1. **Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters
|
||||||
|
2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets
|
||||||
|
3. **Training Monitor** — Real-time loss chart, progress bar, metrics display
|
||||||
|
4. **Export** — Export trained models in various formats
|
||||||
|
|
||||||
|
## Dataset Formats
|
||||||
|
|
||||||
|
Datasets should follow standard HuggingFace formats:
|
||||||
|
|
||||||
|
- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT
|
||||||
|
- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields)
|
||||||
|
- **GRPO**: Prompts with reward signals
|
||||||
|
|
||||||
|
Supported file formats: `.json`, `.jsonl`, `.csv`
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Fine-tuning uses the same gRPC backend architecture as inference:
|
||||||
|
|
||||||
|
1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel`
|
||||||
|
2. **Python backends**: Each backend implements the gRPC interface with its specific training framework
|
||||||
|
3. **Go service**: Manages job lifecycle, routes API requests to backends
|
||||||
|
4. **REST API**: HTTP endpoints with SSE progress streaming
|
||||||
|
5. **React UI**: Configuration form, real-time training monitor, export panel
|
||||||
12
go.mod
12
go.mod
@@ -67,10 +67,18 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/chasefleming/elem-go v0.30.0 // indirect
|
||||||
|
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 // indirect
|
||||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||||
|
github.com/gofiber/template v1.8.3 // indirect
|
||||||
|
github.com/gofiber/template/html/v2 v2.1.3 // indirect
|
||||||
|
github.com/gofiber/utils v1.1.0 // indirect
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
|
github.com/spf13/cobra v1.10.2 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.9 // indirect
|
||||||
github.com/stretchr/testify v1.11.1 // indirect
|
github.com/stretchr/testify v1.11.1 // indirect
|
||||||
github.com/tmc/langchaingo v0.1.14 // indirect
|
github.com/tmc/langchaingo v0.1.14 // indirect
|
||||||
)
|
)
|
||||||
@@ -136,8 +144,8 @@ require (
|
|||||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||||
github.com/labstack/gommon v0.4.2 // indirect
|
github.com/labstack/gommon v0.4.2 // indirect
|
||||||
github.com/mschoch/smat v0.2.0 // indirect
|
github.com/mschoch/smat v0.2.0 // indirect
|
||||||
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a
|
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4
|
||||||
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 // indirect
|
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b // indirect
|
||||||
github.com/mudler/skillserver v0.0.5
|
github.com/mudler/skillserver v0.0.5
|
||||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
||||||
|
|||||||
24
go.sum
24
go.sum
@@ -148,6 +148,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Y
|
|||||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
|
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
|
||||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||||
|
github.com/chasefleming/elem-go v0.30.0 h1:BlhV1ekv1RbFiM8XZUQeln1Ikb4D+bu2eDO4agREvok=
|
||||||
|
github.com/chasefleming/elem-go v0.30.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f6vg71RUilJAA4=
|
||||||
github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM=
|
github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM=
|
||||||
github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY=
|
github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY=
|
||||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||||
@@ -177,6 +179,7 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
|
|||||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
github.com/creachadair/mds v0.21.3 h1:RRgEAPIb52cU0q7UxGyN+13QlCVTZIL4slRr0cYYQfA=
|
github.com/creachadair/mds v0.21.3 h1:RRgEAPIb52cU0q7UxGyN+13QlCVTZIL4slRr0cYYQfA=
|
||||||
github.com/creachadair/mds v0.21.3/go.mod h1:1ltMWZd9yXhaHEoZwBialMaviWVUpRPvMwVP7saFAzM=
|
github.com/creachadair/mds v0.21.3/go.mod h1:1ltMWZd9yXhaHEoZwBialMaviWVUpRPvMwVP7saFAzM=
|
||||||
github.com/creachadair/otp v0.5.0 h1:q3Th7CXm2zlmCdBjw5tEPFOj4oWJMnVL5HXlq0sNKS0=
|
github.com/creachadair/otp v0.5.0 h1:q3Th7CXm2zlmCdBjw5tEPFOj4oWJMnVL5HXlq0sNKS0=
|
||||||
@@ -185,6 +188,8 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
|||||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||||
github.com/cyphar/filepath-securejoin v0.5.1 h1:eYgfMq5yryL4fbWfkLpFFy2ukSELzaJOTaUTuh+oF48=
|
github.com/cyphar/filepath-securejoin v0.5.1 h1:eYgfMq5yryL4fbWfkLpFFy2ukSELzaJOTaUTuh+oF48=
|
||||||
github.com/cyphar/filepath-securejoin v0.5.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
|
github.com/cyphar/filepath-securejoin v0.5.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
|
||||||
|
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2 h1:flLYmnQFZNo04x2NPehMbf30m7Pli57xwZ0NFqR/hb0=
|
||||||
|
github.com/dave-gray101/v2keyauth v0.0.0-20240624150259-c45d584d25e2/go.mod h1:NtWqRzAp/1tw+twkW8uuBenEVVYndEAZACWU3F3xdoQ=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
@@ -341,6 +346,12 @@ github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
|||||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
|
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
|
||||||
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||||
|
github.com/gofiber/template v1.8.3 h1:hzHdvMwMo/T2kouz2pPCA0zGiLCeMnoGsQZBTSYgZxc=
|
||||||
|
github.com/gofiber/template v1.8.3/go.mod h1:bs/2n0pSNPOkRa5VJ8zTIvedcI/lEYxzV3+YPXdBvq8=
|
||||||
|
github.com/gofiber/template/html/v2 v2.1.3 h1:n1LYBtmr9C0V/k/3qBblXyMxV5B0o/gpb6dFLp8ea+o=
|
||||||
|
github.com/gofiber/template/html/v2 v2.1.3/go.mod h1:U5Fxgc5KpyujU9OqKzy6Kn6Qup6Tm7zdsISR+VpnHRE=
|
||||||
|
github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM=
|
||||||
|
github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0=
|
||||||
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
|
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
|
||||||
github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0=
|
github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
@@ -445,6 +456,8 @@ github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI
|
|||||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||||
github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc=
|
github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc=
|
||||||
github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
|
github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/ipfs/boxo v0.30.0 h1:7afsoxPGGqfoH7Dum/wOTGUB9M5fb8HyKPMlLfBvIEQ=
|
github.com/ipfs/boxo v0.30.0 h1:7afsoxPGGqfoH7Dum/wOTGUB9M5fb8HyKPMlLfBvIEQ=
|
||||||
github.com/ipfs/boxo v0.30.0/go.mod h1:BPqgGGyHB9rZZcPSzah2Dc9C+5Or3U1aQe7EH1H7370=
|
github.com/ipfs/boxo v0.30.0/go.mod h1:BPqgGGyHB9rZZcPSzah2Dc9C+5Or3U1aQe7EH1H7370=
|
||||||
github.com/ipfs/go-block-format v0.2.0 h1:ZqrkxBA2ICbDRbK8KJs/u0O3dlp6gmAuuXUJNiW1Ycs=
|
github.com/ipfs/go-block-format v0.2.0 h1:ZqrkxBA2ICbDRbK8KJs/u0O3dlp6gmAuuXUJNiW1Ycs=
|
||||||
@@ -666,6 +679,8 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM=
|
|||||||
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
|
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
|
||||||
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a h1:combrnE/eLPnUhqrYmtFmqEfR6x9xS+HoTFdnMozvik=
|
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a h1:combrnE/eLPnUhqrYmtFmqEfR6x9xS+HoTFdnMozvik=
|
||||||
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a/go.mod h1:AbBcAE9JqkexN4aG8rYQn5LzmzffWqcMvQ+Nlvin3WI=
|
github.com/mudler/LocalAGI v0.0.0-20260319174513-43c65ec7e88a/go.mod h1:AbBcAE9JqkexN4aG8rYQn5LzmzffWqcMvQ+Nlvin3WI=
|
||||||
|
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4 h1:zWrAdAI/gwAPwXQAJuFLF8vvJdsxpxjKiBiC0EzhLOo=
|
||||||
|
github.com/mudler/LocalAGI v0.0.0-20260321004723-b485b77037c4/go.mod h1:g+6CD5tP4a+rRW20CrMpE/JDazq5N4n4YDxIT7tT1mY=
|
||||||
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b h1:A74T2Lauvg61KodYqsjTYDY05kPLcW+efVZjd23dghU=
|
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b h1:A74T2Lauvg61KodYqsjTYDY05kPLcW+efVZjd23dghU=
|
||||||
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
|
github.com/mudler/cogito v0.9.5-0.20260315222927-63abdec7189b/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
|
||||||
github.com/mudler/edgevpn v0.31.1 h1:7qegiDWd0kAg6ljhNHxqvp8hbo/6BbzSdbb7/2WZfiY=
|
github.com/mudler/edgevpn v0.31.1 h1:7qegiDWd0kAg6ljhNHxqvp8hbo/6BbzSdbb7/2WZfiY=
|
||||||
@@ -676,6 +691,10 @@ github.com/mudler/go-processmanager v0.1.0 h1:fcSKgF9U/a1Z7KofAFeZnke5YseadCI5Gq
|
|||||||
github.com/mudler/go-processmanager v0.1.0/go.mod h1:h6kmHUZeafr+k5hRYpGLMzJFH4hItHffgpRo2QIkP+o=
|
github.com/mudler/go-processmanager v0.1.0/go.mod h1:h6kmHUZeafr+k5hRYpGLMzJFH4hItHffgpRo2QIkP+o=
|
||||||
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 h1:KVTEukvLlQXKZx1C1ZLru+ahaiECLF+7v2caK8vauJ0=
|
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62 h1:KVTEukvLlQXKZx1C1ZLru+ahaiECLF+7v2caK8vauJ0=
|
||||||
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
|
github.com/mudler/localrecall v0.5.9-0.20260319170742-933f68603f62/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
|
||||||
|
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45 h1:+zTrbYk70wHrtvpsO2k7gMPvHYnWYCnXNxAtMex+7yg=
|
||||||
|
github.com/mudler/localrecall v0.5.9-0.20260321003356-422f3b1fff45/go.mod h1:/d2bG9H8G/HzsnXTTQl2bOD+ui74XwpeiSDJ+2gdkGc=
|
||||||
|
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b h1:XeAnOEOOSKMfS5XNGpRTltQgjKCinho0V4uAhrgxN7Q=
|
||||||
|
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b/go.mod h1:xuPtgL9zUyiQLmspYzO3kaboYrGbWmwi8BQPt1aCAcs=
|
||||||
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2 h1:+WHsL/j6EWOMUiMVIOJNKOwSKiQt/qDPc9fePCf87fA=
|
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2 h1:+WHsL/j6EWOMUiMVIOJNKOwSKiQt/qDPc9fePCf87fA=
|
||||||
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
|
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
|
||||||
github.com/mudler/skillserver v0.0.5 h1:t6HPpeSX8kEP7B8F5GXoQUam5VEYNmJuG6oy2/vdTu8=
|
github.com/mudler/skillserver v0.0.5 h1:t6HPpeSX8kEP7B8F5GXoQUam5VEYNmJuG6oy2/vdTu8=
|
||||||
@@ -855,6 +874,7 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR
|
|||||||
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
|
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
|
||||||
github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY=
|
github.com/russross/blackfriday v1.6.0/go.mod h1:ti0ldHuxg49ri4ksnFxlkCfN+hvslNlmVHqNRXXJNAY=
|
||||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
|
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
|
||||||
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
|
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
|
||||||
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
|
||||||
@@ -928,6 +948,10 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b
|
|||||||
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||||
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||||
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
|
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||||
|
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||||
|
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||||
|
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
|
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
|
||||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
|
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
|
||||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
|
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
|
||||||
|
|||||||
@@ -63,4 +63,11 @@ type Backend interface {
|
|||||||
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
|
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
|
||||||
|
|
||||||
ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)
|
ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error)
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error)
|
||||||
|
FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error
|
||||||
|
StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
|
ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error)
|
||||||
|
ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,6 +120,26 @@ func (llm *Base) AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, err
|
|||||||
return nil, fmt.Errorf("unimplemented")
|
return nil, fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *Base) StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
|
||||||
|
return nil, fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *Base) FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error {
|
||||||
|
return fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *Base) StopFineTune(*pb.FineTuneStopRequest) error {
|
||||||
|
return fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *Base) ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
|
||||||
|
return nil, fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *Base) ExportModel(*pb.ExportModelRequest) error {
|
||||||
|
return fmt.Errorf("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
func memoryUsage() *pb.MemoryUsageData {
|
func memoryUsage() *pb.MemoryUsageData {
|
||||||
mud := pb.MemoryUsageData{
|
mud := pb.MemoryUsageData{
|
||||||
Breakdown: make(map[string]uint64),
|
Breakdown: make(map[string]uint64),
|
||||||
|
|||||||
@@ -632,6 +632,142 @@ func (c *Client) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opt
|
|||||||
return client.AudioDecode(ctx, in, opts...)
|
return client.AudioDecode(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||||
|
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.StartFineTune(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||||
|
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
|
||||||
|
stream, err := client.FineTuneProgress(ctx, in, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
update, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f(update)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||||
|
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.StopFineTune(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||||
|
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.ListCheckpoints(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
if !c.parallel {
|
||||||
|
c.opMutex.Lock()
|
||||||
|
defer c.opMutex.Unlock()
|
||||||
|
}
|
||||||
|
c.setBusy(true)
|
||||||
|
defer c.setBusy(false)
|
||||||
|
c.wdMark()
|
||||||
|
defer c.wdUnMark()
|
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
|
||||||
|
grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
client := pb.NewBackendClient(conn)
|
||||||
|
return client.ExportModel(ctx, in, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
func (c *Client) ModelMetadata(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
||||||
if !c.parallel {
|
if !c.parallel {
|
||||||
c.opMutex.Lock()
|
c.opMutex.Lock()
|
||||||
|
|||||||
@@ -123,6 +123,68 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques
|
|||||||
return e.s.GetMetrics(ctx, in)
|
return e.s.GetMetrics(ctx, in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) StartFineTune(ctx context.Context, in *pb.FineTuneRequest, opts ...grpc.CallOption) (*pb.FineTuneJobResult, error) {
|
||||||
|
return e.s.StartFineTune(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) FineTuneProgress(ctx context.Context, in *pb.FineTuneProgressRequest, f func(update *pb.FineTuneProgressUpdate), opts ...grpc.CallOption) error {
|
||||||
|
bs := &embedBackendFineTuneProgressStream{
|
||||||
|
ctx: ctx,
|
||||||
|
fn: f,
|
||||||
|
}
|
||||||
|
return e.s.FineTuneProgress(in, bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
return e.s.StopFineTune(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest, opts ...grpc.CallOption) (*pb.ListCheckpointsResponse, error) {
|
||||||
|
return e.s.ListCheckpoints(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackend) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...grpc.CallOption) (*pb.Result, error) {
|
||||||
|
return e.s.ExportModel(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream)
|
||||||
|
|
||||||
|
type embedBackendFineTuneProgressStream struct {
|
||||||
|
ctx context.Context
|
||||||
|
fn func(update *pb.FineTuneProgressUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) Send(update *pb.FineTuneProgressUpdate) error {
|
||||||
|
e.fn(update)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) SetHeader(md metadata.MD) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) SendHeader(md metadata.MD) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) SetTrailer(md metadata.MD) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) Context() context.Context {
|
||||||
|
return e.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) SendMsg(m any) error {
|
||||||
|
if x, ok := m.(*pb.FineTuneProgressUpdate); ok {
|
||||||
|
return e.Send(x)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *embedBackendFineTuneProgressStream) RecvMsg(m any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type embedBackendServerStream struct {
|
type embedBackendServerStream struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
fn func(reply *pb.Reply)
|
fn func(reply *pb.Reply)
|
||||||
|
|||||||
@@ -35,6 +35,13 @@ type AIModel interface {
|
|||||||
AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)
|
AudioDecode(*pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error)
|
||||||
|
|
||||||
ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)
|
ModelMetadata(*pb.ModelOptions) (*pb.ModelMetadataResponse, error)
|
||||||
|
|
||||||
|
// Fine-tuning
|
||||||
|
StartFineTune(*pb.FineTuneRequest) (*pb.FineTuneJobResult, error)
|
||||||
|
FineTuneProgress(*pb.FineTuneProgressRequest, chan *pb.FineTuneProgressUpdate) error
|
||||||
|
StopFineTune(*pb.FineTuneStopRequest) error
|
||||||
|
ListCheckpoints(*pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error)
|
||||||
|
ExportModel(*pb.ExportModelRequest) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newReply(s string) *pb.Reply {
|
func newReply(s string) *pb.Reply {
|
||||||
|
|||||||
@@ -308,6 +308,75 @@ func (s *server) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest) (*p
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) StartFineTune(ctx context.Context, in *pb.FineTuneRequest) (*pb.FineTuneJobResult, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
res, err := s.llm.StartFineTune(in)
|
||||||
|
if err != nil {
|
||||||
|
return &pb.FineTuneJobResult{Success: false, Message: fmt.Sprintf("Error starting fine-tune: %s", err.Error())}, err
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) FineTuneProgress(in *pb.FineTuneProgressRequest, stream pb.Backend_FineTuneProgressServer) error {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
updateChan := make(chan *pb.FineTuneProgressUpdate)
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for update := range updateChan {
|
||||||
|
stream.Send(update)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := s.llm.FineTuneProgress(in, updateChan)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) StopFineTune(ctx context.Context, in *pb.FineTuneStopRequest) (*pb.Result, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
err := s.llm.StopFineTune(in)
|
||||||
|
if err != nil {
|
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error stopping fine-tune: %s", err.Error()), Success: false}, err
|
||||||
|
}
|
||||||
|
return &pb.Result{Message: "Fine-tune stopped", Success: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) ListCheckpoints(ctx context.Context, in *pb.ListCheckpointsRequest) (*pb.ListCheckpointsResponse, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
res, err := s.llm.ListCheckpoints(in)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) ExportModel(ctx context.Context, in *pb.ExportModelRequest) (*pb.Result, error) {
|
||||||
|
if s.llm.Locking() {
|
||||||
|
s.llm.Lock()
|
||||||
|
defer s.llm.Unlock()
|
||||||
|
}
|
||||||
|
err := s.llm.ExportModel(in)
|
||||||
|
if err != nil {
|
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error exporting model: %s", err.Error()), Success: false}, err
|
||||||
|
}
|
||||||
|
return &pb.Result{Message: "Model exported", Success: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
|
func (s *server) ModelMetadata(ctx context.Context, in *pb.ModelOptions) (*pb.ModelMetadataResponse, error) {
|
||||||
if s.llm.Locking() {
|
if s.llm.Locking() {
|
||||||
s.llm.Lock()
|
s.llm.Lock()
|
||||||
|
|||||||
Reference in New Issue
Block a user