diff --git a/core/gallery/importers/local.go b/core/gallery/importers/local.go new file mode 100644 index 000000000..2a456cc60 --- /dev/null +++ b/core/gallery/importers/local.go @@ -0,0 +1,205 @@ +package importers + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/xlog" +) + +// ImportLocalPath scans a local directory for exported model files and produces +// a config.ModelConfig with the correct backend, model path, and options. +// Paths in the returned config are relative to modelsPath when possible so that +// the YAML config remains portable. +// +// Detection order: +// 1. GGUF files (*.gguf) — uses llama-cpp backend +// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter +// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend +func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { + // Make paths relative to the models directory (parent of dirPath) + // so config YAML stays portable. + modelsDir := filepath.Dir(dirPath) + relPath := func(absPath string) string { + if rel, err := filepath.Rel(modelsDir, absPath); err == nil { + return rel + } + return absPath + } + + // 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention) + ggufFile := findGGUF(dirPath) + if ggufFile == "" { + ggufSubdir := dirPath + "_gguf" + ggufFile = findGGUF(ggufSubdir) + } + if ggufFile != "" { + xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile) + cfg := &config.ModelConfig{ + Name: name, + Backend: "llama-cpp", + KnownUsecaseStrings: []string{"chat"}, + Options: []string{"use_jinja:true"}, + } + cfg.Model = relPath(ggufFile) + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.Description = buildDescription(dirPath, "GGUF") + return cfg, nil + } + + // 2. 
LoRA adapter: look for adapter_config.json + + adapterConfigPath := filepath.Join(dirPath, "adapter_config.json") + if fileExists(adapterConfigPath) { + xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath) + baseModel := readBaseModel(dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = baseModel + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.LLMConfig.LoraAdapter = relPath(dirPath) + cfg.Description = buildDescription(dirPath, "LoRA adapter") + return cfg, nil + } + + // Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json + if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) { + xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath) + baseModel := readBaseModel(dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = baseModel + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.LLMConfig.LoraAdapter = relPath(dirPath) + cfg.Description = buildDescription(dirPath, "LoRA adapter") + return cfg, nil + } + + // 3. Merged model: *.safetensors or pytorch_model*.bin + config.json + if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) { + xlog.Info("ImportLocalPath: detected merged model", "path", dirPath) + cfg := &config.ModelConfig{ + Name: name, + Backend: "transformers", + KnownUsecaseStrings: []string{"chat"}, + } + cfg.Model = relPath(dirPath) + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.Description = buildDescription(dirPath, "merged model") + return cfg, nil + } + + return nil, fmt.Errorf("could not detect model format in directory %s", dirPath) +} + +// findGGUF returns the path to the first .gguf file found in dir, or "". 
// findGGUF returns the path of the first *.gguf file (case-insensitive) found
// directly inside dir, or "" when dir is unreadable or holds no GGUF file.
func findGGUF(dir string) string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return ""
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.HasSuffix(strings.ToLower(entry.Name()), ".gguf") {
			return filepath.Join(dir, entry.Name())
		}
	}
	return ""
}

// readBaseModel reads the base model name from adapter_config.json (written by
// TRL/PEFT, key "base_model_name_or_path") or export_metadata.json (written by
// Unsloth, key "base_model"), in that order. It returns "" when neither file
// exists, parses, and names a base model.
func readBaseModel(dirPath string) string {
	sources := []struct {
		file string
		key  string
	}{
		{"adapter_config.json", "base_model_name_or_path"},
		{"export_metadata.json", "base_model"},
	}
	for _, src := range sources {
		raw, err := os.ReadFile(filepath.Join(dirPath, src.file))
		if err != nil {
			continue
		}
		var doc map[string]any
		if err := json.Unmarshal(raw, &doc); err != nil {
			continue
		}
		if base, ok := doc[src.key].(string); ok && base != "" {
			return base
		}
	}
	return ""
}

// buildDescription creates a human-readable description using available metadata.
+func buildDescription(dirPath, formatLabel string) string { + base := "" + + // Try adapter_config.json + if data, err := os.ReadFile(filepath.Join(dirPath, "adapter_config.json")); err == nil { + var ac map[string]any + if json.Unmarshal(data, &ac) == nil { + if bm, ok := ac["base_model_name_or_path"].(string); ok && bm != "" { + base = bm + } + } + } + + // Try export_metadata.json + if base == "" { + if data, err := os.ReadFile(filepath.Join(dirPath, "export_metadata.json")); err == nil { + var meta map[string]any + if json.Unmarshal(data, &meta) == nil { + if bm, ok := meta["base_model"].(string); ok && bm != "" { + base = bm + } + } + } + } + + if base != "" { + return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel) + } + return fmt.Sprintf("Fine-tuned model (%s)", formatLabel) +} + +func fileExists(path string) bool { + info, err := os.Stat(path) + return err == nil && !info.IsDir() +} + +func hasFileWithSuffix(dir, suffix string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) { + return true + } + } + return false +} + +func hasFileWithPrefix(dir, prefix string) bool { + entries, err := os.ReadDir(dir) + if err != nil { + return false + } + for _, e := range entries { + if !e.IsDir() && strings.HasPrefix(e.Name(), prefix) { + return true + } + } + return false +} diff --git a/core/gallery/importers/local_test.go b/core/gallery/importers/local_test.go new file mode 100644 index 000000000..b43433406 --- /dev/null +++ b/core/gallery/importers/local_test.go @@ -0,0 +1,148 @@ +package importers_test + +import ( + "encoding/json" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/gallery/importers" +) + +var _ = Describe("ImportLocalPath", func() { + var tmpDir string + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "importers-local-test") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tmpDir) + }) + + Context("GGUF detection", func() { + It("detects a GGUF file in the directory", func() { + modelDir := filepath.Join(tmpDir, "my-model") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "my-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("llama-cpp")) + Expect(cfg.Model).To(ContainSubstring(".gguf")) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat")) + Expect(cfg.Options).To(ContainElement("use_jinja:true")) + }) + + It("detects GGUF in _gguf subdirectory", func() { + modelDir := filepath.Join(tmpDir, "my-model") + ggufDir := modelDir + "_gguf" + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "my-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("llama-cpp")) + }) + }) + + Context("LoRA adapter detection", func() { + It("detects LoRA adapter via adapter_config.json", func() { + modelDir := filepath.Join(tmpDir, "lora-model") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{ + "base_model_name_or_path": "meta-llama/Llama-2-7b-hf", + "peft_type": "LORA", + } + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + cfg, err := 
importers.ImportLocalPath(modelDir, "lora-model") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf")) + Expect(cfg.LLMConfig.LoraAdapter).To(Equal(modelDir)) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + }) + + It("reads base model from export_metadata.json as fallback", func() { + modelDir := filepath.Join(tmpDir, "lora-unsloth") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{"peft_type": "LORA"} + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"} + data, _ = json.Marshal(metadata) + Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit")) + }) + }) + + Context("Merged model detection", func() { + It("detects merged model with safetensors + config.json", func() { + modelDir := filepath.Join(tmpDir, "merged") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "merged") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal(modelDir)) + Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue()) + }) + + It("detects merged model with pytorch_model files", func() { + modelDir := filepath.Join(tmpDir, "merged-pt") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 
0644)).To(Succeed()) + Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "merged-pt") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Backend).To(Equal("transformers")) + Expect(cfg.Model).To(Equal(modelDir)) + }) + }) + + Context("fallback", func() { + It("returns error for empty directory", func() { + modelDir := filepath.Join(tmpDir, "empty") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + _, err := importers.ImportLocalPath(modelDir, "empty") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("could not detect model format")) + }) + }) + + Context("description", func() { + It("includes base model name in description", func() { + modelDir := filepath.Join(tmpDir, "desc-test") + Expect(os.MkdirAll(modelDir, 0755)).To(Succeed()) + + adapterConfig := map[string]any{ + "base_model_name_or_path": "TinyLlama/TinyLlama-1.1B", + } + data, _ := json.Marshal(adapterConfig) + Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed()) + + cfg, err := importers.ImportLocalPath(modelDir, "desc-test") + Expect(err).ToNot(HaveOccurred()) + Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B")) + Expect(cfg.Description).To(ContainSubstring("Fine-tuned from")) + }) + }) +}) diff --git a/core/http/endpoints/localai/finetune.go b/core/http/endpoints/localai/finetune.go new file mode 100644 index 000000000..ddc0b834a --- /dev/null +++ b/core/http/endpoints/localai/finetune.go @@ -0,0 +1,266 @@ +package localai + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" +) + +// StartFineTuneJobEndpoint starts a new fine-tuning job. 
+func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + + var req schema.FineTuneJobRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + if req.Model == "" { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "model is required", + }) + } + if req.DatasetSource == "" { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "dataset_source is required", + }) + } + + resp, err := ftService.StartJob(c.Request().Context(), userID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusCreated, resp) + } +} + +// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user. +func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobs := ftService.ListJobs(userID) + if jobs == nil { + jobs = []*schema.FineTuneJob{} + } + return c.JSON(http.StatusOK, jobs) + } +} + +// GetFineTuneJobEndpoint gets a specific fine-tuning job. +func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + job, err := ftService.GetJob(userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, job) + } +} + +// StopFineTuneJobEndpoint stops a running fine-tuning job. 
// StopFineTuneJobEndpoint stops a running fine-tuning job owned by the
// current user. The optional ?save_checkpoint=true query parameter asks the
// service to persist a checkpoint before stopping. Service errors are mapped
// to 404 — NOTE(review): assumes StopJob only fails when the job does not
// exist or is not visible to this user; verify against the service.
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		// Check for save_checkpoint query param; anything but the literal
		// string "true" (including absence) means "do not save".
		saveCheckpoint := c.QueryParam("save_checkpoint") == "true"

		err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
		if err != nil {
			return c.JSON(http.StatusNotFound, map[string]string{
				"error": err.Error(),
			})
		}

		return c.JSON(http.StatusOK, map[string]string{
			"status":  "stopped",
			"message": "Fine-tuning job stopped",
		})
	}
}

// FineTuneProgressEndpoint streams progress updates via SSE.
// Headers are committed before streaming starts, so any later failure can
// only be reported in-band as an SSE "data:" event, never as an HTTP status.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		// Set SSE headers and commit a 200 response up front.
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().WriteHeader(http.StatusOK)

		// Each progress event is serialized to JSON and flushed immediately so
		// clients observe updates in real time; events that fail to marshal
		// are silently dropped (best effort).
		err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
			data, err := json.Marshal(event)
			if err != nil {
				return
			}
			fmt.Fprintf(c.Response(), "data: %s\n\n", data)
			c.Response().Flush()
		})
		if err != nil {
			// If headers already sent, we can't send a JSON error; emit the
			// failure as a final SSE event instead (%q quotes and escapes the
			// message text).
			fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
			c.Response().Flush()
		}

		return nil
	}
}

// ListCheckpointsEndpoint lists checkpoints for a job.
+func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID) + if err != nil { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]any{ + "checkpoints": checkpoints, + }) + } +} + +// ExportModelEndpoint exports a model from a checkpoint. +func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + userID := getUserID(c) + jobID := c.Param("id") + + var req schema.ExportRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "Invalid request: " + err.Error(), + }) + } + + modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusAccepted, map[string]string{ + "status": "exporting", + "message": "Export started for model '" + modelName + "'", + "model_name": modelName, + }) + } +} + +// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning". 
+func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to list backends: " + err.Error(), + }) + } + + type backendInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Tags []string `json:"tags,omitempty"` + } + + var result []backendInfo + for _, b := range backends { + if !b.Installed { + continue + } + hasTag := false + for _, t := range b.Tags { + if strings.EqualFold(t, "fine-tuning") { + hasTag = true + break + } + } + if !hasTag { + continue + } + name := b.Name + if b.Alias != "" { + name = b.Alias + } + result = append(result, backendInfo{ + Name: name, + Description: b.Description, + Tags: b.Tags, + }) + } + + if result == nil { + result = []backendInfo{} + } + + return c.JSON(http.StatusOK, result) + } +} + +// UploadDatasetEndpoint handles dataset file upload. 
+func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc { + return func(c echo.Context) error { + file, err := c.FormFile("file") + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "file is required", + }) + } + + src, err := file.Open() + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to open file", + }) + } + defer src.Close() + + data, err := io.ReadAll(src) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": "failed to read file", + }) + } + + path, err := ftService.UploadDataset(file.Filename, data) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]string{ + "path": path, + }) + } +} diff --git a/core/http/react-ui/src/pages/FineTune.jsx b/core/http/react-ui/src/pages/FineTune.jsx new file mode 100644 index 000000000..701d7f891 --- /dev/null +++ b/core/http/react-ui/src/pages/FineTune.jsx @@ -0,0 +1,1294 @@ +import { useState, useEffect, useRef, useCallback } from 'react' +import { fineTuneApi } from '../utils/api' +import LoadingSpinner from '../components/LoadingSpinner' + +const TRAINING_METHODS = ['sft', 'dpo', 'grpo', 'rloo', 'reward', 'kto', 'orpo'] +const TRAINING_TYPES = ['lora', 'loha', 'lokr', 'full'] +const FALLBACK_BACKENDS = ['trl'] +const OPTIMIZERS = ['adamw_torch', 'adamw_8bit', 'sgd', 'adafactor', 'prodigy'] +const MIXED_PRECISION_OPTS = ['', 'fp16', 'bf16', 'no'] + +const statusBadgeClass = { + queued: '', + loading_model: 'badge-warning', + loading_dataset: 'badge-warning', + training: 'badge-info', + saving: 'badge-info', + completed: 'badge-success', + failed: 'badge-error', + stopped: '', +} + +function FormSection({ icon, title, children }) { + return ( +
+

+ + {title} +

+ {children} +
+ ) +} + +function KeyValueEditor({ entries, onChange }) { + const addEntry = () => onChange([...entries, { key: '', value: '' }]) + const removeEntry = (i) => onChange(entries.filter((_, idx) => idx !== i)) + const updateEntry = (i, field, val) => { + const updated = entries.map((e, idx) => idx === i ? { ...e, [field]: val } : e) + onChange(updated) + } + + return ( +
+ {entries.map((entry, i) => ( +
+ updateEntry(i, 'key', e.target.value)} + placeholder="Key" + style={{ flex: 1 }} + /> + updateEntry(i, 'value', e.target.value)} + placeholder="Value" + style={{ flex: 2 }} + /> + +
+ ))} + +
+ ) +} + +function CopyButton({ text }) { + const [copied, setCopied] = useState(false) + const handleCopy = (e) => { + e.stopPropagation() + navigator.clipboard.writeText(text).then(() => { + setCopied(true) + setTimeout(() => setCopied(false), 1500) + }) + } + return ( + + ) +} + +function JobCard({ job, isSelected, onSelect, onUseConfig }) { + return ( +
onSelect(job)} + > +
+
+ {job.model} + + {job.backend} / {job.training_method || 'sft'} + +
+
+ + + {job.status} + +
+
+
+ ID: {job.id?.slice(0, 8)}... | Created: {job.created_at} +
+ {job.output_dir && ( +
+ + + {job.output_dir} + + +
+ )} + {job.message && ( +
+ + {job.message} +
+ )} +
+ ) +} + +function formatEta(seconds) { + if (!seconds || seconds <= 0) return '--' + const h = Math.floor(seconds / 3600) + const m = Math.floor((seconds % 3600) / 60) + const s = Math.floor(seconds % 60) + if (h > 0) return `${h}h ${m}m` + if (m > 0) return `${m}m ${s}s` + return `${s}s` +} + +function formatAxisValue(val, decimals) { + if (val >= 1) return val.toFixed(Math.min(decimals, 1)) + if (val >= 0.01) return val.toFixed(Math.min(decimals, 3)) + return val.toExponential(1) +} + +function TrainingChart({ events }) { + const [tooltip, setTooltip] = useState(null) + const svgRef = useRef(null) + + if (!events || events.length < 2) return null + + const pad = { top: 20, right: 60, bottom: 40, left: 60 } + const W = 600, H = 300 + const cw = W - pad.left - pad.right + const ch = H - pad.top - pad.bottom + + const steps = events.map(e => e.current_step) + const losses = events.map(e => e.loss) + const lrs = events.map(e => e.learning_rate).filter(v => v != null && v > 0) + const hasLr = lrs.length > 1 + + const minStep = Math.min(...steps), maxStep = Math.max(...steps) + const stepRange = maxStep - minStep || 1 + const minLoss = Math.min(...losses), maxLoss = Math.max(...losses) + const lossRange = maxLoss - minLoss || 1 + const lossPad = lossRange * 0.05 + const yMin = Math.max(0, minLoss - lossPad), yMax = maxLoss + lossPad + const yRange = yMax - yMin || 1 + + const x = (step) => pad.left + ((step - minStep) / stepRange) * cw + const yLoss = (loss) => pad.top + (1 - (loss - yMin) / yRange) * ch + + // Loss polyline + const lossPoints = events.map(e => `${x(e.current_step)},${yLoss(e.loss)}`).join(' ') + + // Learning rate polyline (scaled to right axis) + let lrPoints = '' + let lrMin = 0, lrMax = 1, lrRange = 1 + if (hasLr) { + lrMin = Math.min(...lrs) + lrMax = Math.max(...lrs) + lrRange = lrMax - lrMin || 1 + const lrPad = lrRange * 0.05 + lrMin = Math.max(0, lrMin - lrPad) + lrMax = lrMax + lrPad + lrRange = lrMax - lrMin || 1 + const yLr = (lr) => 
pad.top + (1 - (lr - lrMin) / lrRange) * ch + lrPoints = events + .filter(e => e.learning_rate != null && e.learning_rate > 0) + .map(e => `${x(e.current_step)},${yLr(e.learning_rate)}`) + .join(' ') + } + + // Axis ticks + const xTickCount = Math.min(6, events.length) + const xTicks = Array.from({ length: xTickCount }, (_, i) => { + const step = minStep + (stepRange * i) / (xTickCount - 1) + return Math.round(step) + }) + + const yTickCount = 5 + const yTicks = Array.from({ length: yTickCount }, (_, i) => { + return yMin + (yRange * i) / (yTickCount - 1) + }) + + // LR axis ticks (right) + const lrTicks = hasLr ? Array.from({ length: yTickCount }, (_, i) => { + return lrMin + (lrRange * i) / (yTickCount - 1) + }) : [] + const yLrTick = (lr) => pad.top + (1 - (lr - lrMin) / lrRange) * ch + + // Epoch boundary markers + const epochBoundaries = [] + for (let i = 1; i < events.length; i++) { + const prevEpoch = Math.floor(events[i - 1].current_epoch || 0) + const curEpoch = Math.floor(events[i].current_epoch || 0) + if (curEpoch > prevEpoch && curEpoch > 0) { + epochBoundaries.push({ step: events[i].current_step, epoch: curEpoch }) + } + } + + const handleMouseMove = (e) => { + if (!svgRef.current) return + const rect = svgRef.current.getBoundingClientRect() + const mx = ((e.clientX - rect.left) / rect.width) * W + const step = minStep + ((mx - pad.left) / cw) * stepRange + // Find nearest event + let nearest = events[0], bestDist = Infinity + for (const ev of events) { + const d = Math.abs(ev.current_step - step) + if (d < bestDist) { bestDist = d; nearest = ev } + } + setTooltip({ x: x(nearest.current_step), y: yLoss(nearest.loss), data: nearest }) + } + + return ( +
+
+ Training Curves + + Loss + + {hasLr && ( + + Learning Rate + + )} +
+ setTooltip(null)} + > + {/* Grid lines */} + {yTicks.map((val, i) => ( + + ))} + + {/* Epoch boundary markers */} + {epochBoundaries.map((eb, i) => ( + + + + Epoch {eb.epoch} + + + ))} + + {/* Loss curve */} + + + {/* Learning rate curve */} + {hasLr && lrPoints && ( + + )} + + {/* X axis */} + + {xTicks.map((step, i) => ( + + + + {step} + + + ))} + + Step + + + {/* Y axis (left - Loss) */} + + {yTicks.map((val, i) => ( + + + + {formatAxisValue(val, 3)} + + + ))} + + Loss + + + {/* Y axis (right - Learning Rate) */} + {hasLr && ( + <> + + {lrTicks.map((val, i) => ( + + + + {val.toExponential(0)} + + + ))} + + LR + + + )} + + {/* Tooltip */} + {tooltip && ( + + + + + + Step: {tooltip.data.current_step} | Epoch: {(tooltip.data.current_epoch || 0).toFixed(1)} + + + Loss: {tooltip.data.loss?.toFixed(4)} + + {tooltip.data.learning_rate > 0 && ( + + LR: {tooltip.data.learning_rate?.toExponential(2)} + + )} + + )} + +
+ ) +} + +function TrainingMonitor({ job, onStop }) { + const [events, setEvents] = useState([]) + const [latest, setLatest] = useState(null) + const eventSourceRef = useRef(null) + + useEffect(() => { + if (!job || !['queued', 'loading_model', 'loading_dataset', 'training', 'saving'].includes(job.status)) return + + const url = fineTuneApi.progressUrl(job.id) + const es = new EventSource(url) + eventSourceRef.current = es + + es.onmessage = (e) => { + try { + const data = JSON.parse(e.data) + setLatest(data) + if (data.loss > 0) { + setEvents(prev => [...prev, data]) + } + if (['completed', 'failed', 'stopped'].includes(data.status)) { + es.close() + } + } catch (_) {} + } + + es.onerror = () => { + es.close() + } + + return () => { + es.close() + } + }, [job]) + + if (!job) return null + + return ( +
+

+ + Training Monitor +

+ + {latest && ( +
+
+
Status
+
{latest.status}
+
+
+
Progress
+
{latest.progress_percent?.toFixed(1)}%
+
+
+
Step
+
{latest.current_step} / {latest.total_steps}
+
+
+
Loss
+
{latest.loss?.toFixed(4)}
+
+
+
Epoch
+
{latest.current_epoch?.toFixed(2)} / {latest.total_epochs?.toFixed(0)}
+
+
+
Learning Rate
+
{latest.learning_rate?.toExponential(2)}
+
+
+
ETA
+
{formatEta(latest.eta_seconds)}
+
+ {latest.extra_metrics?.tokens_per_second > 0 && ( +
+
Tokens/sec
+
{latest.extra_metrics.tokens_per_second.toFixed(0)}
+
+ )} +
+ )} + + {/* Progress bar */} + {latest && ( +
+
+
+ )} + + {/* Training chart */} + + + {latest?.message && ( +
+ + {latest.message} +
+ )} + + {['queued', 'loading_model', 'loading_dataset', 'training', 'saving'].includes(latest?.status || job.status) && ( + + )} +
+ ) +} + +function CheckpointsPanel({ job, onResume, onExportCheckpoint }) { + const [checkpoints, setCheckpoints] = useState([]) + const [loading, setLoading] = useState(false) + + useEffect(() => { + if (!job) return + setLoading(true) + fineTuneApi.listCheckpoints(job.id).then(r => { + setCheckpoints(r.checkpoints || []) + }).catch(() => {}).finally(() => setLoading(false)) + }, [job]) + + if (!job) return null + if (loading) return
Loading checkpoints...
+ if (checkpoints.length === 0) return null + + return ( +
+

+ + Checkpoints +

+
+ + + + + + + + + + + + + {checkpoints.map(cp => ( + + + + + + + + + ))} + +
StepEpochLossCreatedPathActions
{cp.step}{cp.epoch?.toFixed(2)}{cp.loss?.toFixed(4)}{cp.created_at} + {cp.path} + + + +
+
+
+ ) +} + +const QUANT_PRESETS = ['q4_k_m', 'q5_k_m', 'q8_0', 'f16', 'q4_0', 'q5_0'] + +function ExportPanel({ job, prefilledCheckpoint }) { + const [checkpoints, setCheckpoints] = useState([]) + const [exportFormat, setExportFormat] = useState('lora') + const [quantMethod, setQuantMethod] = useState('q4_k_m') + const [modelName, setModelName] = useState('') + const [selectedCheckpoint, setSelectedCheckpoint] = useState('') + const [exporting, setExporting] = useState(false) + const [message, setMessage] = useState('') + const [exportedModelName, setExportedModelName] = useState('') + const pollRef = useRef(null) + + useEffect(() => { + if (!job) return + fineTuneApi.listCheckpoints(job.id).then(r => { + setCheckpoints(r.checkpoints || []) + }).catch(() => {}) + }, [job]) + + // Apply prefilled checkpoint when set + useEffect(() => { + if (prefilledCheckpoint) { + setSelectedCheckpoint(prefilledCheckpoint.path || '') + } + }, [prefilledCheckpoint]) + + // Sync export state from job (e.g. 
on initial load or job list refresh) + useEffect(() => { + if (!job) return + if (job.export_status === 'exporting') { + setExporting(true) + setMessage('Export in progress...') + } else if (job.export_status === 'completed' && job.export_model_name) { + setExporting(false) + setExportedModelName(job.export_model_name) + setMessage(`Model exported and registered as "${job.export_model_name}"`) + } else if (job.export_status === 'failed') { + setExporting(false) + setMessage(`Export failed: ${job.export_message || 'unknown error'}`) + } + }, [job?.export_status, job?.export_model_name, job?.export_message]) + + // Poll for export completion + useEffect(() => { + if (!exporting || !job) return + + pollRef.current = setInterval(async () => { + try { + const updated = await fineTuneApi.getJob(job.id) + if (updated.export_status === 'completed') { + setExporting(false) + const name = updated.export_model_name || modelName || 'exported model' + setExportedModelName(name) + setMessage(`Model exported and registered as "${name}"`) + clearInterval(pollRef.current) + } else if (updated.export_status === 'failed') { + setExporting(false) + setMessage(`Export failed: ${updated.export_message || 'unknown error'}`) + clearInterval(pollRef.current) + } + } catch (_) {} + }, 3000) + + return () => clearInterval(pollRef.current) + }, [exporting, job?.id]) + + const handleExport = async () => { + setExporting(true) + setMessage('Export in progress...') + setExportedModelName('') + try { + await fineTuneApi.exportModel(job.id, { + name: modelName || undefined, + checkpoint_path: selectedCheckpoint || job.output_dir, + export_format: exportFormat, + quantization_method: exportFormat === 'gguf' ? 
quantMethod : '', + model: job.model, + }) + // Polling will pick up completion/failure + } catch (e) { + setMessage(`Export failed: ${e.message}`) + setExporting(false) + } + } + + // Show export panel for completed, stopped, and failed jobs (checkpoints may exist) + if (!job || !['completed', 'stopped', 'failed'].includes(job.status)) return null + + return ( +
+

+ + Export Model +

+ + {checkpoints.length > 0 && ( +
+ + +
+ )} + +
+
+ + +
+ {exportFormat === 'gguf' && ( +
+ + setQuantMethod(e.target.value)} + placeholder="e.g. q4_k_m, bf16, f32" + className="input" + /> + + {QUANT_PRESETS.map(q => ( + +
+ )} +
+ +
+ + setModelName(e.target.value)} + placeholder="e.g. my-finetuned-model" + className="input" + /> +
+ + + + {message && ( +
+ {message} + {exportedModelName && !message.includes('failed') && ( + + + Chat with {exportedModelName} + + + )} +
+ )} +
+ ) +} + +export default function FineTune() { + const [jobs, setJobs] = useState([]) + const [selectedJob, setSelectedJob] = useState(null) + const [showForm, setShowForm] = useState(false) + const [loading, setLoading] = useState(false) + const [error, setError] = useState('') + const [backends, setBackends] = useState([]) + const [exportCheckpoint, setExportCheckpoint] = useState(null) + + // Form state + const [model, setModel] = useState('') + const [backend, setBackend] = useState('') + const [trainingMethod, setTrainingMethod] = useState('sft') + const [trainingType, setTrainingType] = useState('lora') + const [datasetSource, setDatasetSource] = useState('') + const [datasetFile, setDatasetFile] = useState(null) + const [datasetSplit, setDatasetSplit] = useState('') + const [numEpochs, setNumEpochs] = useState(3) + const [batchSize, setBatchSize] = useState(2) + const [learningRate, setLearningRate] = useState(0.0002) + const [learningRateText, setLearningRateText] = useState('0.0002') + const [adapterRank, setAdapterRank] = useState(16) + const [adapterAlpha, setAdapterAlpha] = useState(16) + const [adapterDropout, setAdapterDropout] = useState(0) + const [targetModules, setTargetModules] = useState('') + const [gradAccum, setGradAccum] = useState(4) + const [warmupSteps, setWarmupSteps] = useState(5) + const [maxSteps, setMaxSteps] = useState(0) + const [saveSteps, setSaveSteps] = useState(500) + const [weightDecay, setWeightDecay] = useState(0) + const [maxSeqLength, setMaxSeqLength] = useState(2048) + const [optimizer, setOptimizer] = useState('adamw_torch') + const [gradCheckpointing, setGradCheckpointing] = useState(false) + const [seed, setSeed] = useState(0) + const [mixedPrecision, setMixedPrecision] = useState('') + const [extraOptions, setExtraOptions] = useState([]) + const [hfToken, setHfToken] = useState('') + const [showAdvanced, setShowAdvanced] = useState(false) + const [resumeFromCheckpoint, setResumeFromCheckpoint] = useState('') + const 
[saveTotalLimit, setSaveTotalLimit] = useState(0) + + const loadJobs = useCallback(async () => { + try { + const data = await fineTuneApi.listJobs() + setJobs(data || []) + } catch (_) {} + }, []) + + useEffect(() => { + loadJobs() + const interval = setInterval(loadJobs, 10000) + return () => clearInterval(interval) + }, [loadJobs]) + + useEffect(() => { + fineTuneApi.listBackends() + .then(data => { + const names = data && data.length > 0 ? data.map(b => b.name) : FALLBACK_BACKENDS + setBackends(names) + setBackend(prev => prev || names[0] || '') + }) + .catch(() => { + setBackends(FALLBACK_BACKENDS) + setBackend(prev => prev || FALLBACK_BACKENDS[0]) + }) + }, []) + + const handleSubmit = async (e) => { + e.preventDefault() + setLoading(true) + setError('') + + try { + let dsSource = datasetSource + if (datasetFile) { + const result = await fineTuneApi.uploadDataset(datasetFile) + dsSource = result.path + } + + const extra = {} + if (maxSeqLength) extra.max_seq_length = String(maxSeqLength) + if (hfToken.trim()) extra.hf_token = hfToken.trim() + if (saveTotalLimit > 0) extra.save_total_limit = String(saveTotalLimit) + for (const { key, value } of extraOptions) { + if (key.trim()) extra[key.trim()] = value + } + + const isAdapter = ['lora', 'loha', 'lokr'].includes(trainingType) + + const req = { + model, + backend, + training_method: trainingMethod, + training_type: trainingType, + dataset_source: dsSource, + dataset_split: datasetSplit || undefined, + num_epochs: numEpochs, + batch_size: batchSize, + learning_rate: learningRate, + adapter_rank: isAdapter ? adapterRank : 0, + adapter_alpha: isAdapter ? adapterAlpha : 0, + adapter_dropout: isAdapter && adapterDropout > 0 ? adapterDropout : undefined, + target_modules: isAdapter && targetModules.trim() ? targetModules.split(',').map(s => s.trim()) : undefined, + gradient_accumulation_steps: gradAccum, + warmup_steps: warmupSteps, + max_steps: maxSteps > 0 ? maxSteps : undefined, + save_steps: saveSteps > 0 ? 
saveSteps : undefined, + weight_decay: weightDecay > 0 ? weightDecay : undefined, + gradient_checkpointing: gradCheckpointing, + optimizer, + seed: seed > 0 ? seed : undefined, + mixed_precision: mixedPrecision || undefined, + resume_from_checkpoint: resumeFromCheckpoint || undefined, + extra_options: Object.keys(extra).length > 0 ? extra : undefined, + } + + const resp = await fineTuneApi.startJob(req) + setShowForm(false) + setResumeFromCheckpoint('') + await loadJobs() + + const newJob = { ...req, id: resp.id, status: 'queued', created_at: new Date().toISOString() } + setSelectedJob(newJob) + } catch (err) { + setError(err.message) + } + setLoading(false) + } + + const handleStop = async (jobId) => { + try { + await fineTuneApi.stopJob(jobId, true) + await loadJobs() + } catch (err) { + setError(err.message) + } + } + + const isAdapter = ['lora', 'loha', 'lokr'].includes(trainingType) + + const getFormConfig = () => { + const extra = {} + for (const { key, value } of extraOptions) { + if (key.trim()) extra[key.trim()] = value + } + return { + model, + backend, + training_method: trainingMethod, + training_type: trainingType, + adapter_rank: adapterRank, + adapter_alpha: adapterAlpha, + adapter_dropout: adapterDropout, + target_modules: targetModules.trim() ? targetModules.split(',').map(s => s.trim()) : [], + dataset_source: datasetSource, + dataset_split: datasetSplit, + num_epochs: numEpochs, + batch_size: batchSize, + learning_rate: learningRate, + gradient_accumulation_steps: gradAccum, + warmup_steps: warmupSteps, + max_steps: maxSteps, + save_steps: saveSteps, + weight_decay: weightDecay, + gradient_checkpointing: gradCheckpointing, + optimizer, + seed, + mixed_precision: mixedPrecision, + max_seq_length: maxSeqLength, + extra_options: Object.keys(extra).length > 0 ? 
extra : {}, + } + } + + const applyFormConfig = (config) => { + if (config.model != null) setModel(config.model) + if (config.backend != null) setBackend(config.backend) + if (config.training_method != null) setTrainingMethod(config.training_method) + if (config.training_type != null) setTrainingType(config.training_type) + if (config.adapter_rank != null) setAdapterRank(Number(config.adapter_rank)) + if (config.adapter_alpha != null) setAdapterAlpha(Number(config.adapter_alpha)) + if (config.adapter_dropout != null) setAdapterDropout(Number(config.adapter_dropout)) + if (config.target_modules != null) { + const modules = Array.isArray(config.target_modules) + ? config.target_modules.join(', ') + : String(config.target_modules) + setTargetModules(modules) + } + if (config.dataset_source != null) setDatasetSource(config.dataset_source) + if (config.dataset_split != null) setDatasetSplit(config.dataset_split) + if (config.num_epochs != null) setNumEpochs(Number(config.num_epochs)) + if (config.batch_size != null) setBatchSize(Number(config.batch_size)) + if (config.learning_rate != null) { setLearningRate(Number(config.learning_rate)); setLearningRateText(String(config.learning_rate)) } + if (config.gradient_accumulation_steps != null) setGradAccum(Number(config.gradient_accumulation_steps)) + if (config.warmup_steps != null) setWarmupSteps(Number(config.warmup_steps)) + if (config.max_steps != null) setMaxSteps(Number(config.max_steps)) + if (config.save_steps != null) setSaveSteps(Number(config.save_steps)) + if (config.weight_decay != null) setWeightDecay(Number(config.weight_decay)) + if (config.gradient_checkpointing != null) setGradCheckpointing(Boolean(config.gradient_checkpointing)) + if (config.optimizer != null) setOptimizer(config.optimizer) + if (config.seed != null) setSeed(Number(config.seed)) + if (config.mixed_precision != null) setMixedPrecision(config.mixed_precision) + + // Handle max_seq_length: top-level field or inside extra_options + if 
(config.max_seq_length != null) { + setMaxSeqLength(Number(config.max_seq_length)) + } else if (config.extra_options?.max_seq_length != null) { + setMaxSeqLength(Number(config.extra_options.max_seq_length)) + } + + // Handle save_total_limit from extra_options + if (config.extra_options?.save_total_limit != null) { + setSaveTotalLimit(Number(config.extra_options.save_total_limit)) + } + + // Convert extra_options object to [{key, value}] entries, filtering out handled keys + if (config.extra_options && typeof config.extra_options === 'object') { + const entries = Object.entries(config.extra_options) + .filter(([k]) => !['max_seq_length', 'save_total_limit', 'hf_token'].includes(k)) + .map(([key, value]) => ({ key, value: String(value) })) + setExtraOptions(entries) + } + } + + const handleExportConfig = () => { + const config = getFormConfig() + const json = JSON.stringify(config, null, 2) + const blob = new Blob([json], { type: 'application/json' }) + const url = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = url + a.download = 'finetune-config.json' + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(url) + } + + const handleImportConfig = () => { + const input = document.createElement('input') + input.type = 'file' + input.accept = '.json' + input.onchange = (e) => { + const file = e.target.files[0] + if (!file) return + const reader = new FileReader() + reader.onload = (ev) => { + try { + const config = JSON.parse(ev.target.result) + applyFormConfig(config) + setShowForm(true) + setError('') + } catch { + setError('Failed to parse config file. 
Please ensure it is valid JSON.') + } + } + reader.readAsText(file) + } + input.click() + } + + const handleUseConfig = (job) => { + // Prefer the stored config if available, otherwise use the job fields + applyFormConfig(job.config || job) + setResumeFromCheckpoint('') + setShowForm(true) + } + + const handleResumeFromCheckpoint = (checkpoint) => { + if (!selectedJob) return + // Apply the original job's config + applyFormConfig(selectedJob.config || selectedJob) + setResumeFromCheckpoint(checkpoint.path) + setShowAdvanced(true) + setShowForm(true) + } + + const handleExportCheckpoint = (checkpoint) => { + setExportCheckpoint(checkpoint) + } + + return ( +
+
+
+

Fine-Tuning

+

Create and manage fine-tuning jobs

+
+
+ + +
+
+ + {error && ( +
+ {error} +
+ )} + + {showForm && ( +
+ + {resumeFromCheckpoint && ( +
+ + + Resuming from checkpoint: {resumeFromCheckpoint} + + +
+ )} + + +
+
+ + +
+
+ + +
+
+ + setModel(e.target.value)} placeholder="e.g. unsloth/tinyllama-bnb-4bit" className="input" required /> +
+
+
+ + setHfToken(e.target.value)} placeholder="hf_..." className="input" /> +
+
+ + +
+
+ + +
+ {isAdapter && ( + <> +
+ + setAdapterRank(Number(e.target.value))} className="input" min={1} /> +
+
+ + setAdapterAlpha(Number(e.target.value))} className="input" min={1} /> +
+
+ + setAdapterDropout(Number(e.target.value))} className="input" min={0} max={1} step={0.05} /> +
+ + )} +
+ {isAdapter && ( +
+ + setTargetModules(e.target.value)} placeholder="e.g. q_proj, v_proj, k_proj, o_proj" className="input" /> +
+ )} +
+ + +
+
+ + setDatasetSource(e.target.value)} placeholder="e.g. tatsu-lab/alpaca" className="input" /> +
+
+ + setDatasetSplit(e.target.value)} placeholder="e.g. train" className="input" /> +
+
+ + setDatasetFile(e.target.files[0])} accept=".json,.jsonl,.csv" className="input" style={{ padding: '6px' }} /> +
+
+
+ + +
+
+ + setNumEpochs(Number(e.target.value))} className="input" min={1} /> +
+
+ + setBatchSize(Number(e.target.value))} className="input" min={1} /> +
+
+ + { + setLearningRateText(e.target.value) + const parsed = Number(e.target.value) + if (!isNaN(parsed) && parsed > 0) setLearningRate(parsed) + }} className="input" placeholder="e.g. 5e-5 or 0.00005" /> +
+
+ + setGradAccum(Number(e.target.value))} className="input" min={1} /> +
+
+ + setWarmupSteps(Number(e.target.value))} className="input" min={0} /> +
+
+ + setMaxSeqLength(Number(e.target.value))} className="input" min={64} /> +
+
+ + +
+
+ +
+
+
+ + {/* Collapsible advanced section */} +
+ + + {showAdvanced && ( +
+
+
+ + setMaxSteps(Number(e.target.value))} className="input" min={0} /> +
+
+ + setSaveSteps(Number(e.target.value))} className="input" min={0} /> +
+
+ + setSaveTotalLimit(Number(e.target.value))} className="input" min={0} /> +
+
+ + setWeightDecay(Number(e.target.value))} className="input" min={0} step={0.01} /> +
+
+ + setSeed(Number(e.target.value))} className="input" min={0} /> +
+
+ + +
+
+ + {resumeFromCheckpoint && ( +
+ +
+ setResumeFromCheckpoint(e.target.value)} className="input" style={{ flex: 1 }} /> + +
+
+ )} + +
+ + +
+
+ )} +
+ +
+ + +
+ + )} + + {/* Jobs list */} +
+
+

Jobs

+ {jobs.length === 0 ? ( +
+
+

No fine-tuning jobs yet

+

Click "New Job" to get started

+
+ ) : ( + jobs.map(job => ( + + )) + )} +
+ + {selectedJob && ( +
+ + + +
+ )} +
+
+ ) +} diff --git a/core/http/routes/finetuning.go b/core/http/routes/finetuning.go new file mode 100644 index 000000000..3df8f9f0f --- /dev/null +++ b/core/http/routes/finetuning.go @@ -0,0 +1,40 @@ +package routes + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/core/services" +) + +// RegisterFineTuningRoutes registers fine-tuning API routes. +func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) { + if ftService == nil { + return + } + + // Service readiness middleware + readyMw := func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if ftService == nil { + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "error": "fine-tuning service is not available", + }) + } + return next(c) + } + } + + ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw) + ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig)) + ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService)) + ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService)) + ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService)) + ft.DELETE("/jobs/:id", localai.StopFineTuneJobEndpoint(ftService)) + ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService)) + ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService)) + ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService)) + ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService)) +} diff --git a/core/schema/finetune.go b/core/schema/finetune.go new file mode 100644 index 000000000..cc1b88460 --- /dev/null +++ b/core/schema/finetune.go @@ -0,0 +1,99 @@ +package schema + +// FineTuneJobRequest is the REST API request to start a fine-tuning job. 
+type FineTuneJobRequest struct { + Model string `json:"model"` + Backend string `json:"backend"` // "unsloth", "trl" + TrainingType string `json:"training_type,omitempty"` // lora, loha, lokr, full + TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo + + // Adapter config + AdapterRank int32 `json:"adapter_rank,omitempty"` + AdapterAlpha int32 `json:"adapter_alpha,omitempty"` + AdapterDropout float32 `json:"adapter_dropout,omitempty"` + TargetModules []string `json:"target_modules,omitempty"` + + // Training hyperparameters + LearningRate float32 `json:"learning_rate,omitempty"` + NumEpochs int32 `json:"num_epochs,omitempty"` + BatchSize int32 `json:"batch_size,omitempty"` + GradientAccumulationSteps int32 `json:"gradient_accumulation_steps,omitempty"` + WarmupSteps int32 `json:"warmup_steps,omitempty"` + MaxSteps int32 `json:"max_steps,omitempty"` + SaveSteps int32 `json:"save_steps,omitempty"` + WeightDecay float32 `json:"weight_decay,omitempty"` + GradientCheckpointing bool `json:"gradient_checkpointing,omitempty"` + Optimizer string `json:"optimizer,omitempty"` + Seed int32 `json:"seed,omitempty"` + MixedPrecision string `json:"mixed_precision,omitempty"` + + // Dataset + DatasetSource string `json:"dataset_source"` + DatasetSplit string `json:"dataset_split,omitempty"` + + // Resume from a checkpoint + ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"` + + // Backend-specific and method-specific options + ExtraOptions map[string]string `json:"extra_options,omitempty"` +} + +// FineTuneJob represents a fine-tuning job with its current state. 
+type FineTuneJob struct { + ID string `json:"id"` + UserID string `json:"user_id,omitempty"` + Model string `json:"model"` + Backend string `json:"backend"` + TrainingType string `json:"training_type"` + TrainingMethod string `json:"training_method"` + Status string `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped + Message string `json:"message,omitempty"` + OutputDir string `json:"output_dir"` + ExtraOptions map[string]string `json:"extra_options,omitempty"` + CreatedAt string `json:"created_at"` + + // Export state (tracked separately from training status) + ExportStatus string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed" + ExportMessage string `json:"export_message,omitempty"` + ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export + + // Full config for resume/reuse + Config *FineTuneJobRequest `json:"config,omitempty"` +} + +// FineTuneJobResponse is the REST API response when creating a job. +type FineTuneJobResponse struct { + ID string `json:"id"` + Status string `json:"status"` + Message string `json:"message"` +} + +// FineTuneProgressEvent is an SSE event for training progress. 
+type FineTuneProgressEvent struct { + JobID string `json:"job_id"` + CurrentStep int32 `json:"current_step"` + TotalSteps int32 `json:"total_steps"` + CurrentEpoch float32 `json:"current_epoch"` + TotalEpochs float32 `json:"total_epochs"` + Loss float32 `json:"loss"` + LearningRate float32 `json:"learning_rate"` + GradNorm float32 `json:"grad_norm"` + EvalLoss float32 `json:"eval_loss"` + EtaSeconds float32 `json:"eta_seconds"` + ProgressPercent float32 `json:"progress_percent"` + Status string `json:"status"` + Message string `json:"message,omitempty"` + CheckpointPath string `json:"checkpoint_path,omitempty"` + SamplePath string `json:"sample_path,omitempty"` + ExtraMetrics map[string]float32 `json:"extra_metrics,omitempty"` +} + +// ExportRequest is the REST API request to export a model. +type ExportRequest struct { + Name string `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty) + CheckpointPath string `json:"checkpoint_path"` + ExportFormat string `json:"export_format"` // lora, merged_16bit, merged_4bit, gguf + QuantizationMethod string `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16 + Model string `json:"model,omitempty"` // base model name for merge + ExtraOptions map[string]string `json:"extra_options,omitempty"` +} diff --git a/core/services/finetune.go b/core/services/finetune.go new file mode 100644 index 000000000..a8e265a75 --- /dev/null +++ b/core/services/finetune.go @@ -0,0 +1,564 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery/importers" + "github.com/mudler/LocalAI/core/schema" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" + "gopkg.in/yaml.v3" +) + +// FineTuneService manages fine-tuning jobs 
and their lifecycle. +type FineTuneService struct { + appConfig *config.ApplicationConfig + modelLoader *model.ModelLoader + configLoader *config.ModelConfigLoader + + mu sync.Mutex + jobs map[string]*schema.FineTuneJob +} + +// NewFineTuneService creates a new FineTuneService. +func NewFineTuneService( + appConfig *config.ApplicationConfig, + modelLoader *model.ModelLoader, + configLoader *config.ModelConfigLoader, +) *FineTuneService { + s := &FineTuneService{ + appConfig: appConfig, + modelLoader: modelLoader, + configLoader: configLoader, + jobs: make(map[string]*schema.FineTuneJob), + } + s.loadAllJobs() + return s +} + +// fineTuneBaseDir returns the base directory for fine-tune job data. +func (s *FineTuneService) fineTuneBaseDir() string { + return filepath.Join(s.appConfig.DataPath, "fine-tune") +} + +// jobDir returns the directory for a specific job. +func (s *FineTuneService) jobDir(jobID string) string { + return filepath.Join(s.fineTuneBaseDir(), jobID) +} + +// saveJobState persists a job's state to disk as state.json. +func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) { + dir := s.jobDir(job.ID) + if err := os.MkdirAll(dir, 0750); err != nil { + xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err) + return + } + + data, err := json.MarshalIndent(job, "", " ") + if err != nil { + xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err) + return + } + + statePath := filepath.Join(dir, "state.json") + if err := os.WriteFile(statePath, data, 0640); err != nil { + xlog.Error("Failed to write job state", "job_id", job.ID, "error", err) + } +} + +// loadAllJobs scans the fine-tune directory for persisted jobs and loads them. 
+func (s *FineTuneService) loadAllJobs() { + baseDir := s.fineTuneBaseDir() + entries, err := os.ReadDir(baseDir) + if err != nil { + // Directory doesn't exist yet — that's fine + return + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + statePath := filepath.Join(baseDir, entry.Name(), "state.json") + data, err := os.ReadFile(statePath) + if err != nil { + continue + } + + var job schema.FineTuneJob + if err := json.Unmarshal(data, &job); err != nil { + xlog.Warn("Failed to parse job state", "path", statePath, "error", err) + continue + } + + // Jobs that were running when we shut down are now stale + if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" { + job.Status = "stopped" + job.Message = "Server restarted while job was running" + } + + // Exports that were in progress are now stale + if job.ExportStatus == "exporting" { + job.ExportStatus = "failed" + job.ExportMessage = "Server restarted while export was running" + } + + s.jobs[job.ID] = &job + } + + if len(s.jobs) > 0 { + xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs)) + } +} + +// StartJob starts a new fine-tuning job. 
+func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + + jobID := uuid.New().String() + + backendName := req.Backend + if backendName == "" { + backendName = "trl" + } + + // Always use DataPath for output — not user-configurable + outputDir := filepath.Join(s.fineTuneBaseDir(), jobID) + + // Build gRPC request + grpcReq := &pb.FineTuneRequest{ + Model: req.Model, + TrainingType: req.TrainingType, + TrainingMethod: req.TrainingMethod, + AdapterRank: req.AdapterRank, + AdapterAlpha: req.AdapterAlpha, + AdapterDropout: req.AdapterDropout, + TargetModules: req.TargetModules, + LearningRate: req.LearningRate, + NumEpochs: req.NumEpochs, + BatchSize: req.BatchSize, + GradientAccumulationSteps: req.GradientAccumulationSteps, + WarmupSteps: req.WarmupSteps, + MaxSteps: req.MaxSteps, + SaveSteps: req.SaveSteps, + WeightDecay: req.WeightDecay, + GradientCheckpointing: req.GradientCheckpointing, + Optimizer: req.Optimizer, + Seed: req.Seed, + MixedPrecision: req.MixedPrecision, + DatasetSource: req.DatasetSource, + DatasetSplit: req.DatasetSplit, + OutputDir: outputDir, + JobId: jobID, + ResumeFromCheckpoint: req.ResumeFromCheckpoint, + ExtraOptions: req.ExtraOptions, + } + + // Load the fine-tuning backend + backendModel, err := s.modelLoader.Load( + model.WithBackendString(backendName), + model.WithModel(backendName), + model.WithModelID(backendName+"-finetune"), + ) + if err != nil { + return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err) + } + + // Start fine-tuning via gRPC + result, err := backendModel.StartFineTune(ctx, grpcReq) + if err != nil { + return nil, fmt.Errorf("failed to start fine-tuning: %w", err) + } + if !result.Success { + return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message) + } + + // Track the job + job := &schema.FineTuneJob{ + ID: jobID, + UserID: userID, + Model: req.Model, + Backend: 
backendName, + TrainingType: req.TrainingType, + TrainingMethod: req.TrainingMethod, + Status: "queued", + OutputDir: outputDir, + ExtraOptions: req.ExtraOptions, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + Config: &req, + } + s.jobs[jobID] = job + s.saveJobState(job) + + return &schema.FineTuneJobResponse{ + ID: jobID, + Status: "queued", + Message: result.Message, + }, nil +} + +// GetJob returns a fine-tuning job by ID. +func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) { + s.mu.Lock() + defer s.mu.Unlock() + + job, ok := s.jobs[jobID] + if !ok { + return nil, fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + return nil, fmt.Errorf("job not found: %s", jobID) + } + return job, nil +} + +// ListJobs returns all jobs for a user. +func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob { + s.mu.Lock() + defer s.mu.Unlock() + + var result []*schema.FineTuneJob + for _, job := range s.jobs { + if userID == "" || job.UserID == userID { + result = append(result, job) + } + } + return result +} + +// StopJob stops a running fine-tuning job. 
+func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + s.mu.Unlock() + + backendModel, err := s.modelLoader.Load( + model.WithBackendString(job.Backend), + model.WithModel(job.Backend), + model.WithModelID(job.Backend+"-finetune"), + ) + if err != nil { + return fmt.Errorf("failed to load backend: %w", err) + } + + _, err = backendModel.StopFineTune(ctx, &pb.FineTuneStopRequest{ + JobId: jobID, + SaveCheckpoint: saveCheckpoint, + }) + if err != nil { + return fmt.Errorf("failed to stop job: %w", err) + } + + s.mu.Lock() + job.Status = "stopped" + s.saveJobState(job) + s.mu.Unlock() + + return nil +} + +// StreamProgress opens a gRPC progress stream and calls the callback for each update. +func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return fmt.Errorf("job not found: %s", jobID) + } + s.mu.Unlock() + + backendModel, err := s.modelLoader.Load( + model.WithBackendString(job.Backend), + model.WithModel(job.Backend), + model.WithModelID(job.Backend+"-finetune"), + ) + if err != nil { + return fmt.Errorf("failed to load backend: %w", err) + } + + return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{ + JobId: jobID, + }, func(update *pb.FineTuneProgressUpdate) { + // Update job status and persist + s.mu.Lock() + if j, ok := s.jobs[jobID]; ok { + j.Status = update.Status + if update.Message != "" { + j.Message = update.Message + } + s.saveJobState(j) + } + s.mu.Unlock() + + // Convert extra metrics + 
extraMetrics := make(map[string]float32) + for k, v := range update.ExtraMetrics { + extraMetrics[k] = v + } + + event := &schema.FineTuneProgressEvent{ + JobID: update.JobId, + CurrentStep: update.CurrentStep, + TotalSteps: update.TotalSteps, + CurrentEpoch: update.CurrentEpoch, + TotalEpochs: update.TotalEpochs, + Loss: update.Loss, + LearningRate: update.LearningRate, + GradNorm: update.GradNorm, + EvalLoss: update.EvalLoss, + EtaSeconds: update.EtaSeconds, + ProgressPercent: update.ProgressPercent, + Status: update.Status, + Message: update.Message, + CheckpointPath: update.CheckpointPath, + SamplePath: update.SamplePath, + ExtraMetrics: extraMetrics, + } + callback(event) + }) +} + +// ListCheckpoints lists checkpoints for a job. +func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return nil, fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return nil, fmt.Errorf("job not found: %s", jobID) + } + s.mu.Unlock() + + backendModel, err := s.modelLoader.Load( + model.WithBackendString(job.Backend), + model.WithModel(job.Backend), + model.WithModelID(job.Backend+"-finetune"), + ) + if err != nil { + return nil, fmt.Errorf("failed to load backend: %w", err) + } + + resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{ + OutputDir: job.OutputDir, + }) + if err != nil { + return nil, fmt.Errorf("failed to list checkpoints: %w", err) + } + + return resp.Checkpoints, nil +} + +// sanitizeModelName replaces non-alphanumeric characters with hyphens and lowercases. 
+func sanitizeModelName(s string) string { + re := regexp.MustCompile(`[^a-zA-Z0-9\-]`) + s = re.ReplaceAllString(s, "-") + s = regexp.MustCompile(`-+`).ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + return strings.ToLower(s) +} + +// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately. +func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) { + s.mu.Lock() + job, ok := s.jobs[jobID] + if !ok { + s.mu.Unlock() + return "", fmt.Errorf("job not found: %s", jobID) + } + if userID != "" && job.UserID != userID { + s.mu.Unlock() + return "", fmt.Errorf("job not found: %s", jobID) + } + if job.ExportStatus == "exporting" { + s.mu.Unlock() + return "", fmt.Errorf("export already in progress for job %s", jobID) + } + s.mu.Unlock() + + // Compute model name + modelName := req.Name + if modelName == "" { + base := sanitizeModelName(job.Model) + if base == "" { + base = "model" + } + shortID := jobID + if len(shortID) > 8 { + shortID = shortID[:8] + } + modelName = base + "-ft-" + shortID + } + + // Compute output path in models directory + modelsPath := s.appConfig.SystemState.Model.ModelsPath + outputPath := filepath.Join(modelsPath, modelName) + + // Check for name collision (synchronous — fast validation) + configPath := filepath.Join(modelsPath, modelName+".yaml") + if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil { + return "", fmt.Errorf("invalid model name: %w", err) + } + if _, err := os.Stat(configPath); err == nil { + return "", fmt.Errorf("model %q already exists, choose a different name", modelName) + } + + // Create output directory + if err := os.MkdirAll(outputPath, 0750); err != nil { + return "", fmt.Errorf("failed to create output directory: %w", err) + } + + // Set export status to "exporting" and persist + s.mu.Lock() + job.ExportStatus = "exporting" + job.ExportMessage = "" + job.ExportModelName = "" + 
s.saveJobState(job) + s.mu.Unlock() + + // Launch the export in a background goroutine + go func() { + backendModel, err := s.modelLoader.Load( + model.WithBackendString(job.Backend), + model.WithModel(job.Backend), + model.WithModelID(job.Backend+"-finetune"), + ) + if err != nil { + s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err)) + return + } + + // Merge job's extra_options (contains hf_token from training) with request's + mergedOpts := make(map[string]string) + for k, v := range job.ExtraOptions { + mergedOpts[k] = v + } + for k, v := range req.ExtraOptions { + mergedOpts[k] = v // request overrides job + } + + grpcReq := &pb.ExportModelRequest{ + CheckpointPath: req.CheckpointPath, + OutputPath: outputPath, + ExportFormat: req.ExportFormat, + QuantizationMethod: req.QuantizationMethod, + Model: req.Model, + ExtraOptions: mergedOpts, + } + + result, err := backendModel.ExportModel(context.Background(), grpcReq) + if err != nil { + s.setExportFailed(job, fmt.Sprintf("export failed: %v", err)) + return + } + if !result.Success { + s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message)) + return + } + + // Auto-import: detect format and generate config + cfg, err := importers.ImportLocalPath(outputPath, modelName) + if err != nil { + s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err)) + return + } + + cfg.Name = modelName + + // If base model not detected from files, use the job's model field + if cfg.Model == "" && job.Model != "" { + cfg.Model = job.Model + } + + // Write YAML config + yamlData, err := yaml.Marshal(cfg) + if err != nil { + s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err)) + return + } + if err := os.WriteFile(configPath, yamlData, 0644); err != nil { + s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err)) + return + } + + // Reload configs so the model is immediately available + if err := 
s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil { + xlog.Warn("Failed to reload configs after export", "error", err) + } + if err := s.configLoader.Preload(modelsPath); err != nil { + xlog.Warn("Failed to preload after export", "error", err) + } + + xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat) + + s.mu.Lock() + job.ExportStatus = "completed" + job.ExportModelName = modelName + job.ExportMessage = "" + s.saveJobState(job) + s.mu.Unlock() + }() + + return modelName, nil +} + +// setExportFailed sets the export status to failed with a message. +func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) { + xlog.Error("Export failed", "job_id", job.ID, "error", message) + s.mu.Lock() + job.ExportStatus = "failed" + job.ExportMessage = message + s.saveJobState(job) + s.mu.Unlock() +} + +// UploadDataset handles dataset file upload and returns the local path. +func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) { + uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets") + if err := os.MkdirAll(uploadDir, 0750); err != nil { + return "", fmt.Errorf("failed to create dataset directory: %w", err) + } + + filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename) + if err := os.WriteFile(filePath, data, 0640); err != nil { + return "", fmt.Errorf("failed to write dataset: %w", err) + } + + return filePath, nil +} + +// MarshalProgressEvent converts a progress event to JSON for SSE. 
+func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) { + data, err := json.Marshal(event) + if err != nil { + return "", err + } + return string(data), nil +} diff --git a/docs/content/features/fine-tuning.md b/docs/content/features/fine-tuning.md new file mode 100644 index 000000000..d94fbfcae --- /dev/null +++ b/docs/content/features/fine-tuning.md @@ -0,0 +1,185 @@ ++++ +disableToc = false +title = "Fine-Tuning" +weight = 18 +url = '/features/fine-tuning/' ++++ + +LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types. + +## Supported Backends + +| Backend | Domain | GPU Required | Training Methods | Adapter Types | +|---------|--------|-------------|-----------------|---------------| +| **unsloth** | LLM fine-tuning | Yes (CUDA) | SFT, GRPO | LoRA/QLoRA | +| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full | + +## Enabling Fine-Tuning + +Fine-tuning is disabled by default. Enable it with: + +```bash +LOCALAI_ENABLE_FINETUNING=true local-ai +``` + +When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API. + +## Quick Start + +### 1. Start a fine-tuning job + +```bash +curl -X POST http://localhost:8080/api/fine-tuning/jobs \ + -H "Content-Type: application/json" \ + -d '{ + "model": "unsloth/tinyllama-bnb-4bit", + "backend": "unsloth", + "training_method": "sft", + "training_type": "lora", + "dataset_source": "yahma/alpaca-cleaned", + "num_epochs": 1, + "batch_size": 2, + "learning_rate": 0.0002, + "adapter_rank": 16, + "adapter_alpha": 16, + "extra_options": { + "max_seq_length": "2048", + "load_in_4bit": "true" + } + }' +``` + +### 2. 
Monitor progress (SSE stream) + +```bash +curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress +``` + +### 3. List checkpoints + +```bash +curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints +``` + +### 4. Export model + +```bash +curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \ + -H "Content-Type: application/json" \ + -d '{ + "export_format": "gguf", + "quantization_method": "q4_k_m", + "output_path": "/models/my-finetuned-model" + }' +``` + +## API Reference + +### Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job | +| `GET` | `/api/fine-tuning/jobs` | List all jobs | +| `GET` | `/api/fine-tuning/jobs/:id` | Get job details | +| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job | +| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream | +| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints | +| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model | +| `POST` | `/api/fine-tuning/datasets` | Upload dataset file | + +### Job Request Fields + +| Field | Type | Description | +|-------|------|-------------| +| `model` | string | HuggingFace model ID or local path (required) | +| `backend` | string | Backend name: `unsloth` or `trl` (default: `trl`) | +| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` | +| `training_type` | string | `lora` or `full` | +| `dataset_source` | string | HuggingFace dataset ID or local file path (required) | +| `adapter_rank` | int | LoRA rank (default: 16) | +| `adapter_alpha` | int | LoRA alpha (default: 16) | +| `num_epochs` | int | Number of training epochs (default: 3) | +| `batch_size` | int | Per-device batch size (default: 2) | +| `learning_rate` | float | Learning rate (default: 2e-4) | +| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) | +| `warmup_steps` | int | Warmup steps 
(default: 5) | +| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` | +| `extra_options` | map | Backend-specific options (see below) | + +### Backend-Specific Options (`extra_options`) + +#### Unsloth + +| Key | Description | Default | +|-----|-------------|---------| +| `max_seq_length` | Maximum sequence length | `2048` | +| `load_in_4bit` | Load model in 4-bit quantization | `true` | +| `packing` | Enable sequence packing | `false` | +| `use_rslora` | Use Rank-Stabilized LoRA | `false` | + +#### TRL + +| Key | Description | Default | +|-----|-------------|---------| +| `max_seq_length` | Maximum sequence length | `512` | +| `packing` | Enable sequence packing | `false` | +| `trust_remote_code` | Trust remote code in model | `false` | +| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` | + +#### DPO-specific (training_method=dpo) + +| Key | Description | Default | +|-----|-------------|---------| +| `beta` | KL penalty coefficient | `0.1` | +| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` | +| `max_length` | Maximum sequence length | `512` | + +#### GRPO-specific (training_method=grpo) + +| Key | Description | Default | +|-----|-------------|---------| +| `num_generations` | Number of generations per prompt | `4` | +| `max_completion_length` | Max completion token length | `256` | + +### Export Formats + +| Format | Description | Notes | +|--------|-------------|-------| +| `lora` | LoRA adapter files | Smallest, requires base model | +| `merged_16bit` | Full model in 16-bit | Large but standalone | +| `merged_4bit` | Full model in 4-bit | Smaller, standalone | +| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` | + +### GGUF Quantization Methods + +`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0` + +## Web UI + +When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides: + +1. 
**Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters +2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets +3. **Training Monitor** — Real-time loss chart, progress bar, metrics display +4. **Export** — Export trained models in various formats + +## Dataset Formats + +Datasets should follow standard HuggingFace formats: + +- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT +- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields) +- **GRPO**: Prompts with reward signals + +Supported file formats: `.json`, `.jsonl`, `.csv` + +## Architecture + +Fine-tuning uses the same gRPC backend architecture as inference: + +1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel` +2. **Python backends**: Each backend implements the gRPC interface with its specific training framework +3. **Go service**: Manages job lifecycle, routes API requests to backends +4. **REST API**: HTTP endpoints with SSE progress streaming +5. **React UI**: Configuration form, real-time training monitor, export panel