feat: add fine-tuning endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-20 17:09:33 +00:00
parent 9cdbd89c1f
commit ae4b758a5a
8 changed files with 2801 additions and 0 deletions

View File

@@ -0,0 +1,205 @@
package importers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/xlog"
)
// ImportLocalPath scans a local directory for exported model files and produces
// a config.ModelConfig with the correct backend, model path, and options.
// Paths in the returned config are relative to modelsPath when possible so that
// the YAML config remains portable.
//
// Detection order:
//  1. GGUF files (*.gguf) — uses llama-cpp backend
//  2. LoRA adapter (adapter_config.json or bare adapter_model.* weights) —
//     uses transformers backend with lora_adapter
//  3. Merged model (*.safetensors or pytorch_model*.bin + config.json) —
//     uses transformers backend
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
	// Make paths relative to the models directory (parent of dirPath)
	// so config YAML stays portable.
	modelsDir := filepath.Dir(dirPath)
	relPath := func(absPath string) string {
		if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
			return rel
		}
		return absPath
	}

	// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
	ggufFile := findGGUF(dirPath)
	if ggufFile == "" {
		ggufFile = findGGUF(dirPath + "_gguf")
	}
	if ggufFile != "" {
		xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "llama-cpp",
			KnownUsecaseStrings: []string{"chat"},
			Options:             []string{"use_jinja:true"},
		}
		cfg.Model = relPath(ggufFile)
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.Description = buildDescription(dirPath, "GGUF")
		return cfg, nil
	}

	// 2. LoRA adapter: adapter_config.json, or adapter weight files when the
	// exporter did not write the config. The two cases previously had two
	// branches that built byte-identical configs; they are folded into one.
	if fileExists(filepath.Join(dirPath, "adapter_config.json")) ||
		fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) ||
		fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
		xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "transformers",
			KnownUsecaseStrings: []string{"chat"},
		}
		// May be "" when no metadata file names the base model; the config
		// is still returned so the user can fill it in manually.
		cfg.Model = readBaseModel(dirPath)
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.LLMConfig.LoraAdapter = relPath(dirPath)
		cfg.Description = buildDescription(dirPath, "LoRA adapter")
		return cfg, nil
	}

	// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
	if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
		xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
		cfg := &config.ModelConfig{
			Name:                name,
			Backend:             "transformers",
			KnownUsecaseStrings: []string{"chat"},
		}
		cfg.Model = relPath(dirPath)
		cfg.TemplateConfig.UseTokenizerTemplate = true
		cfg.Description = buildDescription(dirPath, "merged model")
		return cfg, nil
	}

	return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
}
// findGGUF returns the path to the first .gguf file found in dir, or "".
func findGGUF(dir string) string {
entries, err := os.ReadDir(dir)
if err != nil {
return ""
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), ".gguf") {
return filepath.Join(dir, e.Name())
}
}
return ""
}
// readBaseModel reads the base model name from adapter_config.json or export_metadata.json.
func readBaseModel(dirPath string) string {
// Try adapter_config.json → base_model_name_or_path (TRL writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
return bm
}
}
}
// Try export_metadata.json → base_model (Unsloth writes this)
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
return bm
}
}
}
return ""
}
// buildDescription creates a human-readable description using available metadata.
func buildDescription(dirPath, formatLabel string) string {
base := ""
// Try adapter_config.json
if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil {
var ac map[string]any
if json.Unmarshal(data, &ac) == nil {
if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" {
base = bm
}
}
}
// Try export_metadata.json
if base == "" {
if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil {
var meta map[string]any
if json.Unmarshal(data, &meta) == nil {
if bm, ok := meta["base_model"].(string); ok && bm != "" {
base = bm
}
}
}
}
if base != "" {
return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
}
return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}
func fileExists(path string) bool {
info, err := os.Stat(path)
return err == nil && !info.IsDir()
}
func hasFileWithSuffix(dir, suffix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
return true
}
}
return false
}
func hasFileWithPrefix(dir, prefix string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) {
return true
}
}
return false
}

View File

@@ -0,0 +1,148 @@
package importers_test
import (
"encoding/json"
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/gallery/importers"
)
// ImportLocalPath makes stored paths relative to the parent (models)
// directory via filepath.Rel so the generated YAML stays portable. The
// previous expectations compared against the absolute modelDir and could
// never pass; they now compare against the relative form (the base name,
// since tmpDir is the parent).
var _ = Describe("ImportLocalPath", func() {
	var tmpDir string

	BeforeEach(func() {
		var err error
		tmpDir, err = os.MkdirTemp("", "importers-local-test")
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		os.RemoveAll(tmpDir)
	})

	Context("GGUF detection", func() {
		It("detects a GGUF file in the directory", func() {
			modelDir := filepath.Join(tmpDir, "my-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
			Expect(cfg.Model).To(ContainSubstring(".gguf"))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
			Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
			Expect(cfg.Options).To(ContainElement("use_jinja:true"))
		})
		It("detects GGUF in _gguf subdirectory", func() {
			modelDir := filepath.Join(tmpDir, "my-model")
			ggufDir := modelDir + "_gguf"
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
		})
	})

	Context("LoRA adapter detection", func() {
		It("detects LoRA adapter via adapter_config.json", func() {
			modelDir := filepath.Join(tmpDir, "lora-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			adapterConfig := map[string]any{
				"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
				"peft_type":               "LORA",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
			// Adapter path is stored relative to the models dir (tmpDir).
			Expect(cfg.LLMConfig.LoraAdapter).To(Equal(filepath.Base(modelDir)))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})
		It("reads base model from export_metadata.json as fallback", func() {
			modelDir := filepath.Join(tmpDir, "lora-unsloth")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			adapterConfig := map[string]any{"peft_type": "LORA"}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
			metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
			data, _ = json.Marshal(metadata)
			Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
		})
	})

	Context("Merged model detection", func() {
		It("detects merged model with safetensors + config.json", func() {
			modelDir := filepath.Join(tmpDir, "merged")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "merged")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			// Model path is stored relative to the models dir (tmpDir).
			Expect(cfg.Model).To(Equal(filepath.Base(modelDir)))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})
		It("detects merged model with pytorch_model files", func() {
			modelDir := filepath.Join(tmpDir, "merged-pt")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal(filepath.Base(modelDir)))
		})
	})

	Context("fallback", func() {
		It("returns error for empty directory", func() {
			modelDir := filepath.Join(tmpDir, "empty")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			_, err := importers.ImportLocalPath(modelDir, "empty")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("could not detect model format"))
		})
	})

	Context("description", func() {
		It("includes base model name in description", func() {
			modelDir := filepath.Join(tmpDir, "desc-test")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			adapterConfig := map[string]any{
				"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())
			cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
			Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
		})
	})
})

View File

@@ -0,0 +1,266 @@
package localai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// StartFineTuneJobEndpoint starts a new fine-tuning job.
// It validates the minimal required fields (model, dataset_source) before
// delegating to the fine-tune service, and returns 201 with the job handle.
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		var req schema.FineTuneJobRequest
		if err := c.Bind(&req); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "Invalid request: " + err.Error(),
			})
		}
		// Reject obviously incomplete requests before touching the service.
		switch {
		case req.Model == "":
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "model is required",
			})
		case req.DatasetSource == "":
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "dataset_source is required",
			})
		}
		resp, err := ftService.StartJob(c.Request().Context(), userID, req)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusCreated, resp)
	}
}
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
// The response is always a JSON array, never null, even with no jobs.
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		jobs := ftService.ListJobs(getUserID(c))
		if jobs == nil {
			// Keep the wire format an empty array rather than JSON null.
			jobs = make([]*schema.FineTuneJob, 0)
		}
		return c.JSON(http.StatusOK, jobs)
	}
}
// GetFineTuneJobEndpoint gets a specific fine-tuning job by its :id param.
// Unknown (or not-owned) jobs yield 404.
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		job, err := ftService.GetJob(getUserID(c), c.Param("id"))
		if err != nil {
			return c.JSON(http.StatusNotFound, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusOK, job)
	}
}
// StopFineTuneJobEndpoint stops a running fine-tuning job.
//
// A missing/not-owned job yields 404; any other failure (backend could not
// be loaded, gRPC stop call failed) is a server-side problem and now yields
// 500 instead of the previous blanket 404.
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")
		// Check for save_checkpoint query param
		saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
		if err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint); err != nil {
			// The service reports unknown/unauthorized jobs with a
			// "job not found" message; everything else is internal.
			status := http.StatusInternalServerError
			if strings.Contains(err.Error(), "job not found") {
				status = http.StatusNotFound
			}
			return c.JSON(status, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusOK, map[string]string{
			"status":  "stopped",
			"message": "Fine-tuning job stopped",
		})
	}
}
// FineTuneProgressEndpoint streams progress updates via SSE.
//
// Each backend progress update is forwarded as a `data: <json>` SSE frame
// and flushed immediately. Because the SSE headers and a 200 status are
// written before the stream is opened, a later failure cannot become a JSON
// error response — it is emitted as a final SSE error frame instead.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")
		// Set SSE headers
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().WriteHeader(http.StatusOK)
		err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
			// Events that fail to marshal are silently dropped; the stream
			// continues with the next update.
			data, err := json.Marshal(event)
			if err != nil {
				return
			}
			fmt.Fprintf(c.Response(), "data: %s\n\n", data)
			// Flush per event so the client sees progress in real time.
			c.Response().Flush()
		})
		if err != nil {
			// If headers already sent, we can't send a JSON error
			fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
			c.Response().Flush()
		}
		return nil
	}
}
// ListCheckpointsEndpoint lists checkpoints for a job.
//
// A missing/not-owned job yields 404; backend/gRPC failures now yield 500
// instead of the previous blanket 404 (consistent with the stop endpoint).
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")
		checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
		if err != nil {
			// "job not found" is the service's unknown/unauthorized marker;
			// everything else is an internal failure.
			status := http.StatusInternalServerError
			if strings.Contains(err.Error(), "job not found") {
				status = http.StatusNotFound
			}
			return c.JSON(status, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusOK, map[string]any{
			"checkpoints": checkpoints,
		})
	}
}
// ExportModelEndpoint exports a model from a checkpoint.
// The export runs asynchronously; 202 Accepted is returned along with the
// model name that will be registered once the export completes.
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")
		var exportReq schema.ExportRequest
		if err := c.Bind(&exportReq); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "Invalid request: " + err.Error(),
			})
		}
		modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, exportReq)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusAccepted, map[string]string{
			"status":     "exporting",
			"message":    "Export started for model '" + modelName + "'",
			"model_name": modelName,
		})
	}
}
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
// The response is always a JSON array (possibly empty), never null, and the
// backend alias takes precedence over its canonical name when set.
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
	type backendInfo struct {
		Name        string   `json:"name"`
		Description string   `json:"description,omitempty"`
		Tags        []string `json:"tags,omitempty"`
	}
	// fineTuningTagged reports whether any tag matches "fine-tuning",
	// case-insensitively.
	fineTuningTagged := func(tags []string) bool {
		for _, tag := range tags {
			if strings.EqualFold(tag, "fine-tuning") {
				return true
			}
		}
		return false
	}
	return func(c echo.Context) error {
		backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": "failed to list backends: " + err.Error(),
			})
		}
		result := []backendInfo{}
		for _, b := range backends {
			if !b.Installed || !fineTuningTagged(b.Tags) {
				continue
			}
			name := b.Name
			if b.Alias != "" {
				name = b.Alias
			}
			result = append(result, backendInfo{
				Name:        name,
				Description: b.Description,
				Tags:        b.Tags,
			})
		}
		return c.JSON(http.StatusOK, result)
	}
}
// UploadDatasetEndpoint handles dataset file upload via the "file" multipart
// form field and returns the stored path.
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		fileHeader, err := c.FormFile("file")
		if err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{
				"error": "file is required",
			})
		}
		reader, err := fileHeader.Open()
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": "failed to open file",
			})
		}
		defer reader.Close()
		// NOTE(review): the whole upload is buffered in memory before being
		// handed to the service; confirm expected dataset sizes and consider
		// enforcing a size limit.
		contents, err := io.ReadAll(reader)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": "failed to read file",
			})
		}
		path, err := ftService.UploadDataset(fileHeader.Filename, contents)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{
				"error": err.Error(),
			})
		}
		return c.JSON(http.StatusOK, map[string]string{
			"path": path,
		})
	}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
package routes
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/services"
)
// RegisterFineTuningRoutes registers fine-tuning API routes under
// /api/fine-tuning. Registration is skipped entirely when ftService is nil,
// so none of these routes exist if the feature is disabled.
func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) {
	if ftService == nil {
		return
	}
	// Service readiness middleware
	// NOTE(review): ftService is already nil-checked above and is captured
	// by value, so this middleware can never fire — it appears to be dead
	// code. Confirm whether it can be removed (removing it would also leave
	// the net/http import in this file unused).
	readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if ftService == nil {
				return c.JSON(http.StatusServiceUnavailable, map[string]string{
					"error": "fine-tuning service is not available",
				})
			}
			return next(c)
		}
	}
	ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw)
	ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig))
	ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService))
	ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService))
	ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService))
	ft.DELETE("/jobs/:id", localai.StopFineTuneJobEndpoint(ftService))
	ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService))
	ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService))
	ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService))
	ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService))
}

99
core/schema/finetune.go Normal file
View File

@@ -0,0 +1,99 @@
package schema
// FineTuneJobRequest is the REST API request to start a fine-tuning job.
// Model and DatasetSource are required; everything else is optional and
// backend defaults apply when a field is left at its zero value.
type FineTuneJobRequest struct {
	Model   string `json:"model"`
	Backend string `json:"backend"` // "unsloth", "trl"
	TrainingType   string `json:"training_type,omitempty"`   // lora, loha, lokr, full
	TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo
	// Adapter config
	AdapterRank    int32    `json:"adapter_rank,omitempty"`
	AdapterAlpha   int32    `json:"adapter_alpha,omitempty"`
	AdapterDropout float32  `json:"adapter_dropout,omitempty"`
	TargetModules  []string `json:"target_modules,omitempty"`
	// Training hyperparameters
	LearningRate              float32 `json:"learning_rate,omitempty"`
	NumEpochs                 int32   `json:"num_epochs,omitempty"`
	BatchSize                 int32   `json:"batch_size,omitempty"`
	GradientAccumulationSteps int32   `json:"gradient_accumulation_steps,omitempty"`
	WarmupSteps               int32   `json:"warmup_steps,omitempty"`
	MaxSteps                  int32   `json:"max_steps,omitempty"`
	SaveSteps                 int32   `json:"save_steps,omitempty"`
	WeightDecay               float32 `json:"weight_decay,omitempty"`
	GradientCheckpointing     bool    `json:"gradient_checkpointing,omitempty"`
	Optimizer                 string  `json:"optimizer,omitempty"`
	Seed                      int32   `json:"seed,omitempty"`
	MixedPrecision            string  `json:"mixed_precision,omitempty"`
	// Dataset
	DatasetSource string `json:"dataset_source"`
	DatasetSplit  string `json:"dataset_split,omitempty"`
	// Resume from a checkpoint
	ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"`
	// Backend-specific and method-specific options
	ExtraOptions map[string]string `json:"extra_options,omitempty"`
}
// FineTuneJob represents a fine-tuning job with its current state.
// It is also the on-disk persistence format (state.json in the job dir).
type FineTuneJob struct {
	ID             string `json:"id"`
	UserID         string `json:"user_id,omitempty"`
	Model          string `json:"model"`
	Backend        string `json:"backend"`
	TrainingType   string `json:"training_type"`
	TrainingMethod string `json:"training_method"`
	Status         string `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped
	Message        string `json:"message,omitempty"`
	OutputDir      string `json:"output_dir"`
	ExtraOptions   map[string]string `json:"extra_options,omitempty"`
	CreatedAt      string `json:"created_at"` // RFC3339 UTC timestamp
	// Export state (tracked separately from training status)
	ExportStatus    string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed"
	ExportMessage   string `json:"export_message,omitempty"`
	ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export
	// Full config for resume/reuse
	Config *FineTuneJobRequest `json:"config,omitempty"`
}
// FineTuneJobResponse is the REST API response when creating a job.
type FineTuneJobResponse struct {
	ID      string `json:"id"`     // job identifier for follow-up calls
	Status  string `json:"status"` // initial status, e.g. "queued"
	Message string `json:"message"`
}
// FineTuneProgressEvent is an SSE event for training progress.
// Fields mirror the backend's gRPC progress update.
type FineTuneProgressEvent struct {
	JobID           string  `json:"job_id"`
	CurrentStep     int32   `json:"current_step"`
	TotalSteps      int32   `json:"total_steps"`
	CurrentEpoch    float32 `json:"current_epoch"`
	TotalEpochs     float32 `json:"total_epochs"`
	Loss            float32 `json:"loss"`
	LearningRate    float32 `json:"learning_rate"`
	GradNorm        float32 `json:"grad_norm"`
	EvalLoss        float32 `json:"eval_loss"`
	EtaSeconds      float32 `json:"eta_seconds"`
	ProgressPercent float32 `json:"progress_percent"`
	Status          string  `json:"status"`
	Message         string  `json:"message,omitempty"`
	CheckpointPath  string  `json:"checkpoint_path,omitempty"`
	SamplePath      string  `json:"sample_path,omitempty"`
	// Backend-specific metrics not covered by the fixed fields above.
	ExtraMetrics map[string]float32 `json:"extra_metrics,omitempty"`
}
// ExportRequest is the REST API request to export a model.
type ExportRequest struct {
	Name               string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
	CheckpointPath     string `json:"checkpoint_path"`
	ExportFormat       string `json:"export_format"`       // lora, merged_16bit, merged_4bit, gguf
	QuantizationMethod string `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16
	Model              string `json:"model,omitempty"`     // base model name for merge
	// Backend-specific export options.
	ExtraOptions map[string]string `json:"extra_options,omitempty"`
}

564
core/services/finetune.go Normal file
View File

@@ -0,0 +1,564 @@
package services
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery/importers"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
"gopkg.in/yaml.v3"
)
// FineTuneService manages fine-tuning jobs and their lifecycle.
type FineTuneService struct {
	appConfig    *config.ApplicationConfig
	modelLoader  *model.ModelLoader
	configLoader *config.ModelConfigLoader
	mu   sync.Mutex                      // guards jobs
	jobs map[string]*schema.FineTuneJob  // keyed by job ID (UUID)
}
// NewFineTuneService creates a new FineTuneService and restores any jobs
// persisted by a previous run from the fine-tune data directory.
func NewFineTuneService(
	appConfig *config.ApplicationConfig,
	modelLoader *model.ModelLoader,
	configLoader *config.ModelConfigLoader,
) *FineTuneService {
	svc := &FineTuneService{
		appConfig:    appConfig,
		modelLoader:  modelLoader,
		configLoader: configLoader,
		jobs:         map[string]*schema.FineTuneJob{},
	}
	svc.loadAllJobs()
	return svc
}
// fineTuneBaseDir returns the base directory for fine-tune job data
// (always under the application DataPath; not user-configurable).
func (s *FineTuneService) fineTuneBaseDir() string {
	return filepath.Join(s.appConfig.DataPath, "fine-tune")
}
// jobDir returns the directory for a specific job (one subdirectory per
// job ID under the fine-tune base directory).
func (s *FineTuneService) jobDir(jobID string) string {
	return filepath.Join(s.fineTuneBaseDir(), jobID)
}
// saveJobState persists a job's state to disk as state.json inside the
// job's directory. Failures are logged, never propagated: persistence is
// best-effort and must not break the in-memory job flow.
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
	dir := s.jobDir(job.ID)
	if err := os.MkdirAll(dir, 0750); err != nil {
		xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
		return
	}
	payload, err := json.MarshalIndent(job, "", " ")
	if err != nil {
		xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
		return
	}
	if err := os.WriteFile(filepath.Join(dir, "state.json"), payload, 0640); err != nil {
		xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
	}
}
// loadAllJobs scans the fine-tune directory for persisted jobs and loads
// them into memory. Jobs or exports that were mid-flight when the server
// shut down are marked stale (stopped/failed) since their processes are gone.
func (s *FineTuneService) loadAllJobs() {
	baseDir := s.fineTuneBaseDir()
	entries, err := os.ReadDir(baseDir)
	if err != nil {
		// No base directory yet: nothing has been persisted.
		return
	}
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		statePath := filepath.Join(baseDir, entry.Name(), "state.json")
		data, err := os.ReadFile(statePath)
		if err != nil {
			continue
		}
		var job schema.FineTuneJob
		if err := json.Unmarshal(data, &job); err != nil {
			xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
			continue
		}
		// Anything that was still running when we shut down is stale now.
		switch job.Status {
		case "queued", "loading_model", "loading_dataset", "training", "saving":
			job.Status = "stopped"
			job.Message = "Server restarted while job was running"
		}
		if job.ExportStatus == "exporting" {
			job.ExportStatus = "failed"
			job.ExportMessage = "Server restarted while export was running"
		}
		s.jobs[job.ID] = &job
	}
	if len(s.jobs) > 0 {
		xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
	}
}
// StartJob starts a new fine-tuning job.
//
// The backend load and the gRPC start call happen *before* taking the
// service lock: both can be slow (model loading, process spawn), and the
// previous version held s.mu across them, blocking every other job
// operation for the duration. The job ID is a fresh UUID, so no other
// goroutine can race on the same map entry; the lock is only needed around
// the map mutation and state persistence.
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
	jobID := uuid.New().String()
	backendName := req.Backend
	if backendName == "" {
		backendName = "trl"
	}
	// Always use DataPath for output — not user-configurable
	outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)
	// Build gRPC request
	grpcReq := &pb.FineTuneRequest{
		Model:                     req.Model,
		TrainingType:              req.TrainingType,
		TrainingMethod:            req.TrainingMethod,
		AdapterRank:               req.AdapterRank,
		AdapterAlpha:              req.AdapterAlpha,
		AdapterDropout:            req.AdapterDropout,
		TargetModules:             req.TargetModules,
		LearningRate:              req.LearningRate,
		NumEpochs:                 req.NumEpochs,
		BatchSize:                 req.BatchSize,
		GradientAccumulationSteps: req.GradientAccumulationSteps,
		WarmupSteps:               req.WarmupSteps,
		MaxSteps:                  req.MaxSteps,
		SaveSteps:                 req.SaveSteps,
		WeightDecay:               req.WeightDecay,
		GradientCheckpointing:     req.GradientCheckpointing,
		Optimizer:                 req.Optimizer,
		Seed:                      req.Seed,
		MixedPrecision:            req.MixedPrecision,
		DatasetSource:             req.DatasetSource,
		DatasetSplit:              req.DatasetSplit,
		OutputDir:                 outputDir,
		JobId:                     jobID,
		ResumeFromCheckpoint:      req.ResumeFromCheckpoint,
		ExtraOptions:              req.ExtraOptions,
	}
	// Load the fine-tuning backend
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(backendName),
		model.WithModel(backendName),
		model.WithModelID(backendName+"-finetune"),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
	}
	// Start fine-tuning via gRPC
	result, err := backendModel.StartFineTune(ctx, grpcReq)
	if err != nil {
		return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
	}
	if !result.Success {
		return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
	}
	// Track the job
	job := &schema.FineTuneJob{
		ID:             jobID,
		UserID:         userID,
		Model:          req.Model,
		Backend:        backendName,
		TrainingType:   req.TrainingType,
		TrainingMethod: req.TrainingMethod,
		Status:         "queued",
		OutputDir:      outputDir,
		ExtraOptions:   req.ExtraOptions,
		CreatedAt:      time.Now().UTC().Format(time.RFC3339),
		Config:         &req,
	}
	s.mu.Lock()
	s.jobs[jobID] = job
	s.saveJobState(job)
	s.mu.Unlock()
	return &schema.FineTuneJobResponse{
		ID:      jobID,
		Status:  "queued",
		Message: result.Message,
	}, nil
}
// GetJob returns a fine-tuning job by ID. A job owned by another user is
// reported as not found (no existence leak) — same message either way.
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	job, ok := s.jobs[jobID]
	if !ok || (userID != "" && job.UserID != userID) {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	return job, nil
}
// ListJobs returns all jobs for a user; an empty userID matches every job.
// Returns nil (not an empty slice) when nothing matches, as callers rely on
// a nil check.
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
	s.mu.Lock()
	defer s.mu.Unlock()
	var matched []*schema.FineTuneJob
	for _, job := range s.jobs {
		if userID != "" && job.UserID != userID {
			continue
		}
		matched = append(matched, job)
	}
	return matched
}
// StopJob stops a running fine-tuning job via the backend's gRPC stop call,
// optionally asking the backend to save a checkpoint first. The lock is
// released before the (potentially slow) backend calls and re-taken only to
// record the final status.
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok || (userID != "" && job.UserID != userID) {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(job.Backend+"-finetune"),
	)
	if err != nil {
		return fmt.Errorf("failed to load backend: %w", err)
	}
	if _, err := backendModel.StopFineTune(ctx, &pb.FineTuneStopRequest{
		JobId:          jobID,
		SaveCheckpoint: saveCheckpoint,
	}); err != nil {
		return fmt.Errorf("failed to stop job: %w", err)
	}

	s.mu.Lock()
	job.Status = "stopped"
	s.saveJobState(job)
	s.mu.Unlock()
	return nil
}
// StreamProgress opens a gRPC progress stream and calls the callback for
// each update. Every update also refreshes the tracked job's status and
// persists it, so state.json follows the live stream.
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok || (userID != "" && job.UserID != userID) {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(job.Backend+"-finetune"),
	)
	if err != nil {
		return fmt.Errorf("failed to load backend: %w", err)
	}

	onUpdate := func(update *pb.FineTuneProgressUpdate) {
		// Mirror the update into the tracked job and persist it.
		s.mu.Lock()
		if tracked, found := s.jobs[jobID]; found {
			tracked.Status = update.Status
			if update.Message != "" {
				tracked.Message = update.Message
			}
			s.saveJobState(tracked)
		}
		s.mu.Unlock()

		// Copy backend-specific metrics into the schema map.
		metrics := make(map[string]float32, len(update.ExtraMetrics))
		for key, value := range update.ExtraMetrics {
			metrics[key] = value
		}
		callback(&schema.FineTuneProgressEvent{
			JobID:           update.JobId,
			CurrentStep:     update.CurrentStep,
			TotalSteps:      update.TotalSteps,
			CurrentEpoch:    update.CurrentEpoch,
			TotalEpochs:     update.TotalEpochs,
			Loss:            update.Loss,
			LearningRate:    update.LearningRate,
			GradNorm:        update.GradNorm,
			EvalLoss:        update.EvalLoss,
			EtaSeconds:      update.EtaSeconds,
			ProgressPercent: update.ProgressPercent,
			Status:          update.Status,
			Message:         update.Message,
			CheckpointPath:  update.CheckpointPath,
			SamplePath:      update.SamplePath,
			ExtraMetrics:    metrics,
		})
	}
	return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{JobId: jobID}, onUpdate)
}
// ListCheckpoints lists checkpoints for a job.
//
// The listing itself is delegated to the job's backend over gRPC using the
// job's output directory. A job owned by a different user is reported as
// "not found".
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
	s.mu.Lock()
	job, found := s.jobs[jobID]
	visible := found && (userID == "" || job.UserID == userID)
	s.mu.Unlock()
	if !visible {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(job.Backend+"-finetune"),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load backend: %w", err)
	}
	resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
		OutputDir: job.OutputDir,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to list checkpoints: %w", err)
	}
	return resp.Checkpoints, nil
}
// Model-name sanitization patterns, compiled once at package scope so
// sanitizeModelName does not pay the regexp compilation cost on every call.
var (
	// invalidModelNameChars matches every character that is not an ASCII
	// letter, digit, or hyphen.
	invalidModelNameChars = regexp.MustCompile(`[^a-zA-Z0-9\-]`)
	// hyphenRuns matches one or more consecutive hyphens.
	hyphenRuns = regexp.MustCompile(`-+`)
)

// sanitizeModelName replaces non-alphanumeric characters with hyphens and lowercases.
// Runs of hyphens are collapsed into one, and leading/trailing hyphens are
// trimmed, so e.g. "Meta/Llama-3.1_8B" becomes "meta-llama-3-1-8b". An input
// with no alphanumeric characters yields the empty string.
func sanitizeModelName(s string) string {
	s = invalidModelNameChars.ReplaceAllString(s, "-")
	s = hyphenRuns.ReplaceAllString(s, "-")
	s = strings.Trim(s, "-")
	return strings.ToLower(s)
}
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
//
// The synchronous part validates ownership, computes the target model name,
// rejects path-traversal names and name collisions in the models directory,
// and flips the job's export status to "exporting" (persisted). The actual
// export — backend gRPC call, config generation, YAML write, and config
// reload — runs in a background goroutine; its outcome is reported through
// the job's ExportStatus / ExportMessage / ExportModelName fields.
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return "", fmt.Errorf("job not found: %s", jobID)
	}
	// Hide other users' jobs behind the same "not found" error so job IDs
	// are not leaked across users.
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return "", fmt.Errorf("job not found: %s", jobID)
	}
	// Fast-fail when an export is already running. This check alone is
	// racy (the lock is released during validation below), so it is
	// repeated atomically together with the status update further down.
	if job.ExportStatus == "exporting" {
		s.mu.Unlock()
		return "", fmt.Errorf("export already in progress for job %s", jobID)
	}
	s.mu.Unlock()
	// Compute the model name: the explicit request name, or a sanitized
	// "<base>-ft-<short job id>" fallback derived from the job's model.
	modelName := req.Name
	if modelName == "" {
		base := sanitizeModelName(job.Model)
		if base == "" {
			base = "model"
		}
		shortID := jobID
		if len(shortID) > 8 {
			shortID = shortID[:8]
		}
		modelName = base + "-ft-" + shortID
	}
	// Compute the output and config paths inside the models directory.
	modelsPath := s.appConfig.SystemState.Model.ModelsPath
	outputPath := filepath.Join(modelsPath, modelName)
	configPath := filepath.Join(modelsPath, modelName+".yaml")
	// Reject names that would resolve outside the models directory.
	if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
		return "", fmt.Errorf("invalid model name: %w", err)
	}
	// Check for name collision (synchronous — fast validation).
	if _, err := os.Stat(configPath); err == nil {
		return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
	}
	// Create the output directory up front so the goroutine can assume it exists.
	if err := os.MkdirAll(outputPath, 0750); err != nil {
		return "", fmt.Errorf("failed to create output directory: %w", err)
	}
	// Atomically re-check and set the export status. Two concurrent calls
	// could both have passed the advisory check above; only the first one
	// to reach this section may start the export.
	s.mu.Lock()
	if job.ExportStatus == "exporting" {
		s.mu.Unlock()
		return "", fmt.Errorf("export already in progress for job %s", jobID)
	}
	job.ExportStatus = "exporting"
	job.ExportMessage = ""
	job.ExportModelName = ""
	s.saveJobState(job)
	s.mu.Unlock()
	// Launch the export in a background goroutine.
	go func() {
		backendModel, err := s.modelLoader.Load(
			model.WithBackendString(job.Backend),
			model.WithModel(job.Backend),
			model.WithModelID(job.Backend+"-finetune"),
		)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
			return
		}
		// Merge the job's extra_options (contains hf_token from training)
		// with the request's; request values override the job's.
		mergedOpts := make(map[string]string, len(job.ExtraOptions)+len(req.ExtraOptions))
		for k, v := range job.ExtraOptions {
			mergedOpts[k] = v
		}
		for k, v := range req.ExtraOptions {
			mergedOpts[k] = v
		}
		result, err := backendModel.ExportModel(context.Background(), &pb.ExportModelRequest{
			CheckpointPath:     req.CheckpointPath,
			OutputPath:         outputPath,
			ExportFormat:       req.ExportFormat,
			QuantizationMethod: req.QuantizationMethod,
			Model:              req.Model,
			ExtraOptions:       mergedOpts,
		})
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
			return
		}
		if !result.Success {
			s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
			return
		}
		// Auto-import: detect the exported format and generate a config.
		cfg, err := importers.ImportLocalPath(outputPath, modelName)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
			return
		}
		cfg.Name = modelName
		// If the base model was not detected from the exported files, fall
		// back to the job's model field.
		if cfg.Model == "" && job.Model != "" {
			cfg.Model = job.Model
		}
		// Write the YAML config next to the exported model.
		yamlData, err := yaml.Marshal(cfg)
		if err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
			return
		}
		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
			s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
			return
		}
		// Reload configs so the model is immediately available. Failures
		// here are logged but non-fatal: the export itself succeeded.
		if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
			xlog.Warn("Failed to reload configs after export", "error", err)
		}
		if err := s.configLoader.Preload(modelsPath); err != nil {
			xlog.Warn("Failed to preload after export", "error", err)
		}
		xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
		s.mu.Lock()
		job.ExportStatus = "completed"
		job.ExportModelName = modelName
		job.ExportMessage = ""
		s.saveJobState(job)
		s.mu.Unlock()
	}()
	return modelName, nil
}
// setExportFailed sets the export status to failed with a message.
// The failure is logged, and the updated job record is persisted while
// holding the service mutex.
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
	xlog.Error("Export failed", "job_id", job.ID, "error", message)
	s.mu.Lock()
	defer s.mu.Unlock()
	job.ExportStatus = "failed"
	job.ExportMessage = message
	s.saveJobState(job)
}
// UploadDataset handles dataset file upload and returns the local path.
//
// The file is stored under <fineTuneBaseDir>/datasets with a short random
// prefix so repeated uploads of the same filename do not collide. Only the
// base name of the client-supplied filename is kept, so a crafted name such
// as "../../../etc/passwd" cannot escape the dataset directory.
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
	uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
	if err := os.MkdirAll(uploadDir, 0750); err != nil {
		return "", fmt.Errorf("failed to create dataset directory: %w", err)
	}
	// Strip any directory components from the client-supplied name to
	// prevent path traversal out of uploadDir.
	safeName := filepath.Base(filename)
	filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+safeName)
	if err := os.WriteFile(filePath, data, 0640); err != nil {
		return "", fmt.Errorf("failed to write dataset: %w", err)
	}
	return filePath, nil
}
// MarshalProgressEvent converts a progress event to JSON for SSE.
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
	payload, err := json.Marshal(event)
	// On error payload is nil, so string(payload) is "" — the same
	// zero-value result the caller would otherwise receive.
	return string(payload), err
}

View File

@@ -0,0 +1,185 @@
+++
disableToc = false
title = "Fine-Tuning"
weight = 18
url = '/features/fine-tuning/'
+++
LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types.
## Supported Backends
| Backend | Domain | GPU Required | Training Methods | Adapter Types |
|---------|--------|-------------|-----------------|---------------|
| **unsloth** | LLM fine-tuning | Yes (CUDA) | SFT, GRPO | LoRA/QLoRA |
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |
## Enabling Fine-Tuning
Fine-tuning is disabled by default. Enable it with:
```bash
LOCALAI_ENABLE_FINETUNING=true local-ai
```
When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
## Quick Start
### 1. Start a fine-tuning job
```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
-H "Content-Type: application/json" \
-d '{
"model": "unsloth/tinyllama-bnb-4bit",
"backend": "unsloth",
"training_method": "sft",
"training_type": "lora",
"dataset_source": "yahma/alpaca-cleaned",
"num_epochs": 1,
"batch_size": 2,
"learning_rate": 0.0002,
"adapter_rank": 16,
"adapter_alpha": 16,
"extra_options": {
"max_seq_length": "2048",
"load_in_4bit": "true"
}
}'
```
### 2. Monitor progress (SSE stream)
```bash
curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress
```
### 3. List checkpoints
```bash
curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints
```
### 4. Export model
```bash
curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \
-H "Content-Type: application/json" \
-d '{
"export_format": "gguf",
"quantization_method": "q4_k_m",
"name": "my-finetuned-model"
}'
```
## API Reference
### Endpoints
| Method | Path | Description |
|--------|------|-------------|
| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job |
| `GET` | `/api/fine-tuning/jobs` | List all jobs |
| `GET` | `/api/fine-tuning/jobs/:id` | Get job details |
| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job |
| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream |
| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints |
| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model |
| `POST` | `/api/fine-tuning/datasets` | Upload dataset file |
### Job Request Fields
| Field | Type | Description |
|-------|------|-------------|
| `model` | string | HuggingFace model ID or local path (required) |
| `backend` | string | Backend name: `unsloth` or `trl` (default: `trl`) |
| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` |
| `training_type` | string | `lora` or `full` |
| `dataset_source` | string | HuggingFace dataset ID or local file path (required) |
| `adapter_rank` | int | LoRA rank (default: 16) |
| `adapter_alpha` | int | LoRA alpha (default: 16) |
| `num_epochs` | int | Number of training epochs (default: 3) |
| `batch_size` | int | Per-device batch size (default: 2) |
| `learning_rate` | float | Learning rate (default: 2e-4) |
| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) |
| `warmup_steps` | int | Warmup steps (default: 5) |
| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` |
| `extra_options` | map | Backend-specific options (see below) |
### Backend-Specific Options (`extra_options`)
#### Unsloth
| Key | Description | Default |
|-----|-------------|---------|
| `max_seq_length` | Maximum sequence length | `2048` |
| `load_in_4bit` | Load model in 4-bit quantization | `true` |
| `packing` | Enable sequence packing | `false` |
| `use_rslora` | Use Rank-Stabilized LoRA | `false` |
#### TRL
| Key | Description | Default |
|-----|-------------|---------|
| `max_seq_length` | Maximum sequence length | `512` |
| `packing` | Enable sequence packing | `false` |
| `trust_remote_code` | Trust remote code in model | `false` |
| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` |
#### DPO-specific (training_method=dpo)
| Key | Description | Default |
|-----|-------------|---------|
| `beta` | KL penalty coefficient | `0.1` |
| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` |
| `max_length` | Maximum sequence length | `512` |
#### GRPO-specific (training_method=grpo)
| Key | Description | Default |
|-----|-------------|---------|
| `num_generations` | Number of generations per prompt | `4` |
| `max_completion_length` | Max completion token length | `256` |
### Export Formats
| Format | Description | Notes |
|--------|-------------|-------|
| `lora` | LoRA adapter files | Smallest, requires base model |
| `merged_16bit` | Full model in 16-bit | Large but standalone |
| `merged_4bit` | Full model in 4-bit | Smaller, standalone |
| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` |
### GGUF Quantization Methods
`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0`
## Web UI
When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides:
1. **Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters
2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets
3. **Training Monitor** — Real-time loss chart, progress bar, metrics display
4. **Export** — Export trained models in various formats
## Dataset Formats
Datasets should follow standard HuggingFace formats:
- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT
- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields)
- **GRPO**: Prompts with reward signals
Supported file formats: `.json`, `.jsonl`, `.csv`
## Architecture
Fine-tuning uses the same gRPC backend architecture as inference:
1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel`
2. **Python backends**: Each backend implements the gRPC interface with its specific training framework
3. **Go service**: Manages job lifecycle, routes API requests to backends
4. **REST API**: HTTP endpoints with SSE progress streaming
5. **React UI**: Configuration form, real-time training monitor, export panel