mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-16 21:08:16 -04:00
feat: add fine-tuning endpoint
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
205
core/gallery/importers/local.go
Normal file
205
core/gallery/importers/local.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ImportLocalPath scans a local directory for exported model files and produces
|
||||
// a config.ModelConfig with the correct backend, model path, and options.
|
||||
// Paths in the returned config are relative to modelsPath when possible so that
|
||||
// the YAML config remains portable.
|
||||
//
|
||||
// Detection order:
|
||||
// 1. GGUF files (*.gguf) — uses llama-cpp backend
|
||||
// 2. LoRA adapter (adapter_config.json) — uses transformers backend with lora_adapter
|
||||
// 3. Merged model (*.safetensors or pytorch_model*.bin + config.json) — uses transformers backend
|
||||
func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) {
|
||||
// Make paths relative to the models directory (parent of dirPath)
|
||||
// so config YAML stays portable.
|
||||
modelsDir := filepath.Dir(dirPath)
|
||||
relPath := func(absPath string) string {
|
||||
if rel, err := filepath.Rel(modelsDir, absPath); err == nil {
|
||||
return rel
|
||||
}
|
||||
return absPath
|
||||
}
|
||||
|
||||
// 1. GGUF: check dirPath and dirPath_gguf/ (Unsloth convention)
|
||||
ggufFile := findGGUF(dirPath)
|
||||
if ggufFile == "" {
|
||||
ggufSubdir := dirPath + "_gguf"
|
||||
ggufFile = findGGUF(ggufSubdir)
|
||||
}
|
||||
if ggufFile != "" {
|
||||
xlog.Info("ImportLocalPath: detected GGUF model", "path", ggufFile)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
Options: []string{"use_jinja:true"},
|
||||
}
|
||||
cfg.Model = relPath(ggufFile)
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.Description = buildDescription(dirPath, "GGUF")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 2. LoRA adapter: look for adapter_config.json
|
||||
|
||||
adapterConfigPath := filepath.Join(dirPath, "adapter_config.json")
|
||||
if fileExists(adapterConfigPath) {
|
||||
xlog.Info("ImportLocalPath: detected LoRA adapter", "path", dirPath)
|
||||
baseModel := readBaseModel(dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Also check for adapter_model.safetensors or adapter_model.bin without adapter_config.json
|
||||
if fileExists(filepath.Join(dirPath, "adapter_model.safetensors")) || fileExists(filepath.Join(dirPath, "adapter_model.bin")) {
|
||||
xlog.Info("ImportLocalPath: detected LoRA adapter (by model files)", "path", dirPath)
|
||||
baseModel := readBaseModel(dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = baseModel
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.LLMConfig.LoraAdapter = relPath(dirPath)
|
||||
cfg.Description = buildDescription(dirPath, "LoRA adapter")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// 3. Merged model: *.safetensors or pytorch_model*.bin + config.json
|
||||
if fileExists(filepath.Join(dirPath, "config.json")) && (hasFileWithSuffix(dirPath, ".safetensors") || hasFileWithPrefix(dirPath, "pytorch_model")) {
|
||||
xlog.Info("ImportLocalPath: detected merged model", "path", dirPath)
|
||||
cfg := &config.ModelConfig{
|
||||
Name: name,
|
||||
Backend: "transformers",
|
||||
KnownUsecaseStrings: []string{"chat"},
|
||||
}
|
||||
cfg.Model = relPath(dirPath)
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.Description = buildDescription(dirPath, "merged model")
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not detect model format in directory %s", dirPath)
|
||||
}
|
||||
|
||||
// findGGUF scans dir (non-recursively) and returns the path of the first
// regular entry whose name ends in ".gguf" (case-insensitive), or "" when
// the directory cannot be read or contains no GGUF file.
func findGGUF(dir string) string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return ""
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.HasSuffix(strings.ToLower(entry.Name()), ".gguf") {
			return filepath.Join(dir, entry.Name())
		}
	}
	return ""
}
|
||||
|
||||
// readBaseModel reads the base model name from adapter_config.json or
// export_metadata.json, trying them in that order and returning "" when
// neither yields a non-empty string value.
func readBaseModel(dirPath string) string {
	// Candidate metadata files in priority order, with the JSON key that
	// carries the base model name in each.
	candidates := []struct{ file, key string }{
		{"adapter_config.json", "base_model_name_or_path"}, // TRL writes this
		{"export_metadata.json", "base_model"},             // Unsloth writes this
	}

	for _, cand := range candidates {
		data, err := os.ReadFile(filepath.Join(dirPath, cand.file))
		if err != nil {
			continue
		}
		var doc map[string]any
		if err := json.Unmarshal(data, &doc); err != nil {
			continue
		}
		if name, ok := doc[cand.key].(string); ok && name != "" {
			return name
		}
	}
	return ""
}
|
||||
|
||||
// buildDescription creates a human-readable description using available
// metadata written next to the exported model.
//
// The base model name is looked up first in adapter_config.json
// ("base_model_name_or_path", written by TRL) and then in
// export_metadata.json ("base_model", written by Unsloth). The original
// implementation duplicated the read-file/unmarshal/extract sequence for
// each source; it is folded into one local helper here.
func buildDescription(dirPath, formatLabel string) string {
	// readKey returns the string value stored under key in the named JSON
	// file, or "" on any failure (missing file, bad JSON, absent or
	// non-string key).
	readKey := func(file, key string) string {
		data, err := os.ReadFile(filepath.Join(dirPath, file))
		if err != nil {
			return ""
		}
		var doc map[string]any
		if err := json.Unmarshal(data, &doc); err != nil {
			return ""
		}
		if v, ok := doc[key].(string); ok {
			return v
		}
		return ""
	}

	base := readKey("adapter_config.json", "base_model_name_or_path")
	if base == "" {
		base = readKey("export_metadata.json", "base_model")
	}

	if base != "" {
		return fmt.Sprintf("Fine-tuned from %s (%s)", base, formatLabel)
	}
	return fmt.Sprintf("Fine-tuned model (%s)", formatLabel)
}
|
||||
|
||||
// fileExists reports whether path names an existing non-directory entry.
func fileExists(path string) bool {
	info, err := os.Stat(path)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
|
||||
|
||||
// hasFileWithSuffix reports whether dir contains at least one non-directory
// entry whose name ends in suffix. The comparison is case-insensitive on
// BOTH the entry name and the suffix (the original lowercased only the name,
// so a mixed-case suffix argument could never match). Returns false when the
// directory cannot be read.
func hasFileWithSuffix(dir, suffix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	suffix = strings.ToLower(suffix)
	for _, e := range entries {
		if !e.IsDir() && strings.HasSuffix(strings.ToLower(e.Name()), suffix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// hasFileWithPrefix reports whether dir contains at least one non-directory
// entry whose name starts with prefix (case-sensitive). Returns false when
// the directory cannot be read.
func hasFileWithPrefix(dir, prefix string) bool {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return false
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.HasPrefix(entry.Name(), prefix) {
			return true
		}
	}
	return false
}
|
||||
148
core/gallery/importers/local_test.go
Normal file
148
core/gallery/importers/local_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
)
|
||||
|
||||
// Specs for ImportLocalPath covering the three detection paths (GGUF, LoRA
// adapter, merged model) plus the error fallback and description text.
var _ = Describe("ImportLocalPath", func() {
	var tmpDir string

	// Each spec gets a fresh temp directory that stands in for the models dir.
	BeforeEach(func() {
		var err error
		tmpDir, err = os.MkdirTemp("", "importers-local-test")
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		os.RemoveAll(tmpDir)
	})

	Context("GGUF detection", func() {
		It("detects a GGUF file in the directory", func() {
			modelDir := filepath.Join(tmpDir, "my-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model-q4_k_m.gguf"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
			Expect(cfg.Model).To(ContainSubstring(".gguf"))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
			Expect(cfg.KnownUsecaseStrings).To(ContainElement("chat"))
			Expect(cfg.Options).To(ContainElement("use_jinja:true"))
		})

		It("detects GGUF in _gguf subdirectory", func() {
			// Unsloth convention: weights live in <model>_gguf/ next to the dir.
			modelDir := filepath.Join(tmpDir, "my-model")
			ggufDir := modelDir + "_gguf"
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.MkdirAll(ggufDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(ggufDir, "model.gguf"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "my-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("llama-cpp"))
		})
	})

	Context("LoRA adapter detection", func() {
		It("detects LoRA adapter via adapter_config.json", func() {
			modelDir := filepath.Join(tmpDir, "lora-model")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			adapterConfig := map[string]any{
				"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
				"peft_type":               "LORA",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "lora-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal("meta-llama/Llama-2-7b-hf"))
			// NOTE(review): ImportLocalPath stores LoraAdapter via relPath(),
			// which relativizes against the parent of modelDir — this absolute
			// expectation looks inconsistent with that; verify intent.
			Expect(cfg.LLMConfig.LoraAdapter).To(Equal(modelDir))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})

		It("reads base model from export_metadata.json as fallback", func() {
			modelDir := filepath.Join(tmpDir, "lora-unsloth")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			// adapter_config.json without base_model_name_or_path forces the
			// fallback to export_metadata.json.
			adapterConfig := map[string]any{"peft_type": "LORA"}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			metadata := map[string]any{"base_model": "unsloth/tinyllama-bnb-4bit"}
			data, _ = json.Marshal(metadata)
			Expect(os.WriteFile(filepath.Join(modelDir, "export_metadata.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "lora-unsloth")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Model).To(Equal("unsloth/tinyllama-bnb-4bit"))
		})
	})

	Context("Merged model detection", func() {
		It("detects merged model with safetensors + config.json", func() {
			modelDir := filepath.Join(tmpDir, "merged")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "merged")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			// NOTE(review): same relPath() concern as above — cfg.Model is set
			// from relPath(dirPath); confirm an absolute path is expected here.
			Expect(cfg.Model).To(Equal(modelDir))
			Expect(cfg.TemplateConfig.UseTokenizerTemplate).To(BeTrue())
		})

		It("detects merged model with pytorch_model files", func() {
			modelDir := filepath.Join(tmpDir, "merged-pt")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "config.json"), []byte("{}"), 0644)).To(Succeed())
			Expect(os.WriteFile(filepath.Join(modelDir, "pytorch_model-00001-of-00002.bin"), []byte("fake"), 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "merged-pt")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Backend).To(Equal("transformers"))
			Expect(cfg.Model).To(Equal(modelDir))
		})
	})

	Context("fallback", func() {
		It("returns error for empty directory", func() {
			modelDir := filepath.Join(tmpDir, "empty")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			_, err := importers.ImportLocalPath(modelDir, "empty")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("could not detect model format"))
		})
	})

	Context("description", func() {
		It("includes base model name in description", func() {
			modelDir := filepath.Join(tmpDir, "desc-test")
			Expect(os.MkdirAll(modelDir, 0755)).To(Succeed())

			adapterConfig := map[string]any{
				"base_model_name_or_path": "TinyLlama/TinyLlama-1.1B",
			}
			data, _ := json.Marshal(adapterConfig)
			Expect(os.WriteFile(filepath.Join(modelDir, "adapter_config.json"), data, 0644)).To(Succeed())

			cfg, err := importers.ImportLocalPath(modelDir, "desc-test")
			Expect(err).ToNot(HaveOccurred())
			Expect(cfg.Description).To(ContainSubstring("TinyLlama/TinyLlama-1.1B"))
			Expect(cfg.Description).To(ContainSubstring("Fine-tuned from"))
		})
	})
})
|
||||
266
core/http/endpoints/localai/finetune.go
Normal file
266
core/http/endpoints/localai/finetune.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// StartFineTuneJobEndpoint starts a new fine-tuning job.
|
||||
func StartFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
|
||||
var req schema.FineTuneJobRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "model is required",
|
||||
})
|
||||
}
|
||||
if req.DatasetSource == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "dataset_source is required",
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := ftService.StartJob(c.Request().Context(), userID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusCreated, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ListFineTuneJobsEndpoint lists fine-tuning jobs for the current user.
|
||||
func ListFineTuneJobsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobs := ftService.ListJobs(userID)
|
||||
if jobs == nil {
|
||||
jobs = []*schema.FineTuneJob{}
|
||||
}
|
||||
return c.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
}
|
||||
|
||||
// GetFineTuneJobEndpoint gets a specific fine-tuning job.
|
||||
func GetFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
job, err := ftService.GetJob(userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, job)
|
||||
}
|
||||
}
|
||||
|
||||
// StopFineTuneJobEndpoint stops a running fine-tuning job.
|
||||
func StopFineTuneJobEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
// Check for save_checkpoint query param
|
||||
saveCheckpoint := c.QueryParam("save_checkpoint") == "true"
|
||||
|
||||
err := ftService.StopJob(c.Request().Context(), userID, jobID, saveCheckpoint)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"status": "stopped",
|
||||
"message": "Fine-tuning job stopped",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FineTuneProgressEndpoint streams progress updates via SSE.
//
// NOTE(review): the SSE headers and the 200 status line are committed before
// StreamProgress runs, so a failure there (e.g. unknown job ID) can only be
// reported as an in-stream "error" event, never as an HTTP error status —
// confirm clients handle that.
func FineTuneProgressEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
	return func(c echo.Context) error {
		userID := getUserID(c)
		jobID := c.Param("id")

		// Set SSE headers
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().WriteHeader(http.StatusOK)

		// Each event is serialized to one "data:" frame and flushed
		// immediately so the client sees progress in real time. Marshal
		// failures silently skip the event.
		err := ftService.StreamProgress(c.Request().Context(), userID, jobID, func(event *schema.FineTuneProgressEvent) {
			data, err := json.Marshal(event)
			if err != nil {
				return
			}
			fmt.Fprintf(c.Response(), "data: %s\n\n", data)
			c.Response().Flush()
		})
		if err != nil {
			// If headers already sent, we can't send a JSON error
			fmt.Fprintf(c.Response(), "data: {\"status\":\"error\",\"message\":%q}\n\n", err.Error())
			c.Response().Flush()
		}

		return nil
	}
}
|
||||
|
||||
// ListCheckpointsEndpoint lists checkpoints for a job.
|
||||
func ListCheckpointsEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
checkpoints, err := ftService.ListCheckpoints(c.Request().Context(), userID, jobID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusNotFound, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"checkpoints": checkpoints,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ExportModelEndpoint exports a model from a checkpoint.
|
||||
func ExportModelEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
userID := getUserID(c)
|
||||
jobID := c.Param("id")
|
||||
|
||||
var req schema.ExportRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "Invalid request: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
modelName, err := ftService.ExportModel(c.Request().Context(), userID, jobID, req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusAccepted, map[string]string{
|
||||
"status": "exporting",
|
||||
"message": "Export started for model '" + modelName + "'",
|
||||
"model_name": modelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ListFineTuneBackendsEndpoint returns installed backends tagged with "fine-tuning".
|
||||
func ListFineTuneBackendsEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to list backends: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
type backendInfo struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
var result []backendInfo
|
||||
for _, b := range backends {
|
||||
if !b.Installed {
|
||||
continue
|
||||
}
|
||||
hasTag := false
|
||||
for _, t := range b.Tags {
|
||||
if strings.EqualFold(t, "fine-tuning") {
|
||||
hasTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTag {
|
||||
continue
|
||||
}
|
||||
name := b.Name
|
||||
if b.Alias != "" {
|
||||
name = b.Alias
|
||||
}
|
||||
result = append(result, backendInfo{
|
||||
Name: name,
|
||||
Description: b.Description,
|
||||
Tags: b.Tags,
|
||||
})
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = []backendInfo{}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// UploadDatasetEndpoint handles dataset file upload.
|
||||
func UploadDatasetEndpoint(ftService *services.FineTuneService) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{
|
||||
"error": "file is required",
|
||||
})
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to open file",
|
||||
})
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
data, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": "failed to read file",
|
||||
})
|
||||
}
|
||||
|
||||
path, err := ftService.UploadDataset(file.Filename, data)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]string{
|
||||
"path": path,
|
||||
})
|
||||
}
|
||||
}
|
||||
1294
core/http/react-ui/src/pages/FineTune.jsx
Normal file
1294
core/http/react-ui/src/pages/FineTune.jsx
Normal file
File diff suppressed because it is too large
Load Diff
40
core/http/routes/finetuning.go
Normal file
40
core/http/routes/finetuning.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
)
|
||||
|
||||
// RegisterFineTuningRoutes registers fine-tuning API routes.
//
// When ftService is nil no routes are registered, so requests to
// /api/fine-tuning/* fall through to the router's default handling.
func RegisterFineTuningRoutes(e *echo.Echo, ftService *services.FineTuneService, appConfig *config.ApplicationConfig, fineTuningMw echo.MiddlewareFunc) {
	if ftService == nil {
		return
	}

	// Service readiness middleware
	// NOTE(review): this nil check is unreachable — ftService is guaranteed
	// non-nil by the early return above and the captured pointer cannot
	// change afterwards; the middleware is effectively a pass-through.
	readyMw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if ftService == nil {
				return c.JSON(http.StatusServiceUnavailable, map[string]string{
					"error": "fine-tuning service is not available",
				})
			}
			return next(c)
		}
	}

	// All fine-tuning routes share the readiness middleware and the
	// caller-supplied fineTuningMw.
	ft := e.Group("/api/fine-tuning", readyMw, fineTuningMw)
	ft.GET("/backends", localai.ListFineTuneBackendsEndpoint(appConfig))
	ft.POST("/jobs", localai.StartFineTuneJobEndpoint(ftService))
	ft.GET("/jobs", localai.ListFineTuneJobsEndpoint(ftService))
	ft.GET("/jobs/:id", localai.GetFineTuneJobEndpoint(ftService))
	ft.DELETE("/jobs/:id", localai.StopFineTuneJobEndpoint(ftService))
	ft.GET("/jobs/:id/progress", localai.FineTuneProgressEndpoint(ftService))
	ft.GET("/jobs/:id/checkpoints", localai.ListCheckpointsEndpoint(ftService))
	ft.POST("/jobs/:id/export", localai.ExportModelEndpoint(ftService))
	ft.POST("/datasets", localai.UploadDatasetEndpoint(ftService))
}
|
||||
99
core/schema/finetune.go
Normal file
99
core/schema/finetune.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package schema
|
||||
|
||||
// FineTuneJobRequest is the REST API request to start a fine-tuning job.
//
// Model and DatasetSource are required; all other fields are optional and
// omitted from JSON when zero.
type FineTuneJobRequest struct {
	Model          string `json:"model"`                     // base model to fine-tune (required)
	Backend        string `json:"backend"`                   // "unsloth", "trl"
	TrainingType   string `json:"training_type,omitempty"`   // lora, loha, lokr, full
	TrainingMethod string `json:"training_method,omitempty"` // sft, dpo, grpo, rloo, reward, kto, orpo

	// Adapter config
	AdapterRank    int32    `json:"adapter_rank,omitempty"`
	AdapterAlpha   int32    `json:"adapter_alpha,omitempty"`
	AdapterDropout float32  `json:"adapter_dropout,omitempty"`
	TargetModules  []string `json:"target_modules,omitempty"`

	// Training hyperparameters
	LearningRate              float32 `json:"learning_rate,omitempty"`
	NumEpochs                 int32   `json:"num_epochs,omitempty"`
	BatchSize                 int32   `json:"batch_size,omitempty"`
	GradientAccumulationSteps int32   `json:"gradient_accumulation_steps,omitempty"`
	WarmupSteps               int32   `json:"warmup_steps,omitempty"`
	MaxSteps                  int32   `json:"max_steps,omitempty"`
	SaveSteps                 int32   `json:"save_steps,omitempty"`
	WeightDecay               float32 `json:"weight_decay,omitempty"`
	GradientCheckpointing     bool    `json:"gradient_checkpointing,omitempty"`
	Optimizer                 string  `json:"optimizer,omitempty"`
	Seed                      int32   `json:"seed,omitempty"`
	MixedPrecision            string  `json:"mixed_precision,omitempty"`

	// Dataset
	DatasetSource string `json:"dataset_source"` // required
	DatasetSplit  string `json:"dataset_split,omitempty"`

	// Resume from a checkpoint
	ResumeFromCheckpoint string `json:"resume_from_checkpoint,omitempty"`

	// Backend-specific and method-specific options
	ExtraOptions map[string]string `json:"extra_options,omitempty"`
}
|
||||
|
||||
// FineTuneJob represents a fine-tuning job with its current state.
//
// Jobs are persisted to disk as state.json under the job's output
// directory and reloaded on startup.
type FineTuneJob struct {
	ID             string            `json:"id"`
	UserID         string            `json:"user_id,omitempty"`
	Model          string            `json:"model"`
	Backend        string            `json:"backend"`
	TrainingType   string            `json:"training_type"`
	TrainingMethod string            `json:"training_method"`
	Status         string            `json:"status"` // queued, loading_model, loading_dataset, training, saving, completed, failed, stopped
	Message        string            `json:"message,omitempty"`
	OutputDir      string            `json:"output_dir"`
	ExtraOptions   map[string]string `json:"extra_options,omitempty"`
	CreatedAt      string            `json:"created_at"`

	// Export state (tracked separately from training status)
	ExportStatus    string `json:"export_status,omitempty"` // "", "exporting", "completed", "failed"
	ExportMessage   string `json:"export_message,omitempty"`
	ExportModelName string `json:"export_model_name,omitempty"` // registered model name after export

	// Full config for resume/reuse
	Config *FineTuneJobRequest `json:"config,omitempty"`
}
|
||||
|
||||
// FineTuneJobResponse is the REST API response when creating a job.
type FineTuneJobResponse struct {
	ID      string `json:"id"`      // identifier of the newly created job
	Status  string `json:"status"`  // initial job status
	Message string `json:"message"` // human-readable status message
}
|
||||
|
||||
// FineTuneProgressEvent is an SSE event for training progress.
//
// One event is emitted per progress callback and serialized as a single
// SSE "data:" frame by the progress endpoint.
type FineTuneProgressEvent struct {
	JobID           string             `json:"job_id"`
	CurrentStep     int32              `json:"current_step"`
	TotalSteps      int32              `json:"total_steps"`
	CurrentEpoch    float32            `json:"current_epoch"`
	TotalEpochs     float32            `json:"total_epochs"`
	Loss            float32            `json:"loss"`
	LearningRate    float32            `json:"learning_rate"`
	GradNorm        float32            `json:"grad_norm"`
	EvalLoss        float32            `json:"eval_loss"`
	EtaSeconds      float32            `json:"eta_seconds"`
	ProgressPercent float32            `json:"progress_percent"`
	Status          string             `json:"status"`
	Message         string             `json:"message,omitempty"`
	CheckpointPath  string             `json:"checkpoint_path,omitempty"`
	SamplePath      string             `json:"sample_path,omitempty"`
	ExtraMetrics    map[string]float32 `json:"extra_metrics,omitempty"` // backend-specific metrics not covered above
}
|
||||
|
||||
// ExportRequest is the REST API request to export a model.
type ExportRequest struct {
	Name               string            `json:"name,omitempty"` // model name for LocalAI (auto-generated if empty)
	CheckpointPath     string            `json:"checkpoint_path"`
	ExportFormat       string            `json:"export_format"`       // lora, merged_16bit, merged_4bit, gguf
	QuantizationMethod string            `json:"quantization_method"` // for GGUF: q4_k_m, q5_k_m, q8_0, f16
	Model              string            `json:"model,omitempty"`     // base model name for merge
	ExtraOptions       map[string]string `json:"extra_options,omitempty"`
}
|
||||
564
core/services/finetune.go
Normal file
564
core/services/finetune.go
Normal file
@@ -0,0 +1,564 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// FineTuneService manages fine-tuning jobs and their lifecycle.
type FineTuneService struct {
	appConfig    *config.ApplicationConfig // provides DataPath for job storage
	modelLoader  *model.ModelLoader
	configLoader *config.ModelConfigLoader

	// mu guards jobs.
	mu sync.Mutex
	// jobs is the in-memory index of jobs keyed by job ID; each job is
	// also persisted as state.json under its job directory.
	jobs map[string]*schema.FineTuneJob
}
|
||||
|
||||
// NewFineTuneService creates a new FineTuneService.
|
||||
func NewFineTuneService(
|
||||
appConfig *config.ApplicationConfig,
|
||||
modelLoader *model.ModelLoader,
|
||||
configLoader *config.ModelConfigLoader,
|
||||
) *FineTuneService {
|
||||
s := &FineTuneService{
|
||||
appConfig: appConfig,
|
||||
modelLoader: modelLoader,
|
||||
configLoader: configLoader,
|
||||
jobs: make(map[string]*schema.FineTuneJob),
|
||||
}
|
||||
s.loadAllJobs()
|
||||
return s
|
||||
}
|
||||
|
||||
// fineTuneBaseDir returns the base directory for fine-tune job data.
|
||||
func (s *FineTuneService) fineTuneBaseDir() string {
|
||||
return filepath.Join(s.appConfig.DataPath, "fine-tune")
|
||||
}
|
||||
|
||||
// jobDir returns the directory for a specific job.
|
||||
func (s *FineTuneService) jobDir(jobID string) string {
|
||||
return filepath.Join(s.fineTuneBaseDir(), jobID)
|
||||
}
|
||||
|
||||
// saveJobState persists a job's state to disk as state.json.
|
||||
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
|
||||
dir := s.jobDir(job.ID)
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(job, "", " ")
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
statePath := filepath.Join(dir, "state.json")
|
||||
if err := os.WriteFile(statePath, data, 0640); err != nil {
|
||||
xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
|
||||
func (s *FineTuneService) loadAllJobs() {
|
||||
baseDir := s.fineTuneBaseDir()
|
||||
entries, err := os.ReadDir(baseDir)
|
||||
if err != nil {
|
||||
// Directory doesn't exist yet — that's fine
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
statePath := filepath.Join(baseDir, entry.Name(), "state.json")
|
||||
data, err := os.ReadFile(statePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var job schema.FineTuneJob
|
||||
if err := json.Unmarshal(data, &job); err != nil {
|
||||
xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Jobs that were running when we shut down are now stale
|
||||
if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" {
|
||||
job.Status = "stopped"
|
||||
job.Message = "Server restarted while job was running"
|
||||
}
|
||||
|
||||
// Exports that were in progress are now stale
|
||||
if job.ExportStatus == "exporting" {
|
||||
job.ExportStatus = "failed"
|
||||
job.ExportMessage = "Server restarted while export was running"
|
||||
}
|
||||
|
||||
s.jobs[job.ID] = &job
|
||||
}
|
||||
|
||||
if len(s.jobs) > 0 {
|
||||
xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
|
||||
}
|
||||
}
|
||||
|
||||
// StartJob starts a new fine-tuning job.
|
||||
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
jobID := uuid.New().String()
|
||||
|
||||
backendName := req.Backend
|
||||
if backendName == "" {
|
||||
backendName = "trl"
|
||||
}
|
||||
|
||||
// Always use DataPath for output — not user-configurable
|
||||
outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)
|
||||
|
||||
// Build gRPC request
|
||||
grpcReq := &pb.FineTuneRequest{
|
||||
Model: req.Model,
|
||||
TrainingType: req.TrainingType,
|
||||
TrainingMethod: req.TrainingMethod,
|
||||
AdapterRank: req.AdapterRank,
|
||||
AdapterAlpha: req.AdapterAlpha,
|
||||
AdapterDropout: req.AdapterDropout,
|
||||
TargetModules: req.TargetModules,
|
||||
LearningRate: req.LearningRate,
|
||||
NumEpochs: req.NumEpochs,
|
||||
BatchSize: req.BatchSize,
|
||||
GradientAccumulationSteps: req.GradientAccumulationSteps,
|
||||
WarmupSteps: req.WarmupSteps,
|
||||
MaxSteps: req.MaxSteps,
|
||||
SaveSteps: req.SaveSteps,
|
||||
WeightDecay: req.WeightDecay,
|
||||
GradientCheckpointing: req.GradientCheckpointing,
|
||||
Optimizer: req.Optimizer,
|
||||
Seed: req.Seed,
|
||||
MixedPrecision: req.MixedPrecision,
|
||||
DatasetSource: req.DatasetSource,
|
||||
DatasetSplit: req.DatasetSplit,
|
||||
OutputDir: outputDir,
|
||||
JobId: jobID,
|
||||
ResumeFromCheckpoint: req.ResumeFromCheckpoint,
|
||||
ExtraOptions: req.ExtraOptions,
|
||||
}
|
||||
|
||||
// Load the fine-tuning backend
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(backendName),
|
||||
model.WithModel(backendName),
|
||||
model.WithModelID(backendName+"-finetune"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
|
||||
}
|
||||
|
||||
// Start fine-tuning via gRPC
|
||||
result, err := backendModel.StartFineTune(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
|
||||
}
|
||||
if !result.Success {
|
||||
return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
|
||||
}
|
||||
|
||||
// Track the job
|
||||
job := &schema.FineTuneJob{
|
||||
ID: jobID,
|
||||
UserID: userID,
|
||||
Model: req.Model,
|
||||
Backend: backendName,
|
||||
TrainingType: req.TrainingType,
|
||||
TrainingMethod: req.TrainingMethod,
|
||||
Status: "queued",
|
||||
OutputDir: outputDir,
|
||||
ExtraOptions: req.ExtraOptions,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Config: &req,
|
||||
}
|
||||
s.jobs[jobID] = job
|
||||
s.saveJobState(job)
|
||||
|
||||
return &schema.FineTuneJobResponse{
|
||||
ID: jobID,
|
||||
Status: "queued",
|
||||
Message: result.Message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetJob returns a fine-tuning job by ID.
|
||||
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ListJobs returns all jobs for a user.
|
||||
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var result []*schema.FineTuneJob
|
||||
for _, job := range s.jobs {
|
||||
if userID == "" || job.UserID == userID {
|
||||
result = append(result, job)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// StopJob stops a running fine-tuning job.
|
||||
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(job.Backend),
|
||||
model.WithModel(job.Backend),
|
||||
model.WithModelID(job.Backend+"-finetune"),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load backend: %w", err)
|
||||
}
|
||||
|
||||
_, err = backendModel.StopFineTune(ctx, &pb.FineTuneStopRequest{
|
||||
JobId: jobID,
|
||||
SaveCheckpoint: saveCheckpoint,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop job: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
job.Status = "stopped"
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
|
||||
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(job.Backend),
|
||||
model.WithModel(job.Backend),
|
||||
model.WithModelID(job.Backend+"-finetune"),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load backend: %w", err)
|
||||
}
|
||||
|
||||
return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{
|
||||
JobId: jobID,
|
||||
}, func(update *pb.FineTuneProgressUpdate) {
|
||||
// Update job status and persist
|
||||
s.mu.Lock()
|
||||
if j, ok := s.jobs[jobID]; ok {
|
||||
j.Status = update.Status
|
||||
if update.Message != "" {
|
||||
j.Message = update.Message
|
||||
}
|
||||
s.saveJobState(j)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Convert extra metrics
|
||||
extraMetrics := make(map[string]float32)
|
||||
for k, v := range update.ExtraMetrics {
|
||||
extraMetrics[k] = v
|
||||
}
|
||||
|
||||
event := &schema.FineTuneProgressEvent{
|
||||
JobID: update.JobId,
|
||||
CurrentStep: update.CurrentStep,
|
||||
TotalSteps: update.TotalSteps,
|
||||
CurrentEpoch: update.CurrentEpoch,
|
||||
TotalEpochs: update.TotalEpochs,
|
||||
Loss: update.Loss,
|
||||
LearningRate: update.LearningRate,
|
||||
GradNorm: update.GradNorm,
|
||||
EvalLoss: update.EvalLoss,
|
||||
EtaSeconds: update.EtaSeconds,
|
||||
ProgressPercent: update.ProgressPercent,
|
||||
Status: update.Status,
|
||||
Message: update.Message,
|
||||
CheckpointPath: update.CheckpointPath,
|
||||
SamplePath: update.SamplePath,
|
||||
ExtraMetrics: extraMetrics,
|
||||
}
|
||||
callback(event)
|
||||
})
|
||||
}
|
||||
|
||||
// ListCheckpoints lists checkpoints for a job.
|
||||
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(job.Backend),
|
||||
model.WithModel(job.Backend),
|
||||
model.WithModelID(job.Backend+"-finetune"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load backend: %w", err)
|
||||
}
|
||||
|
||||
resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
|
||||
OutputDir: job.OutputDir,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list checkpoints: %w", err)
|
||||
}
|
||||
|
||||
return resp.Checkpoints, nil
|
||||
}
|
||||
|
||||
// Patterns used by sanitizeModelName, compiled once at package init
// instead of on every call.
var (
	modelNameInvalidChars = regexp.MustCompile(`[^a-zA-Z0-9\-]`)
	modelNameDashRuns     = regexp.MustCompile(`-+`)
)

// sanitizeModelName replaces non-alphanumeric characters with hyphens,
// collapses runs of hyphens, strips leading/trailing hyphens, and
// lowercases the result.
func sanitizeModelName(s string) string {
	s = modelNameInvalidChars.ReplaceAllString(s, "-")
	s = modelNameDashRuns.ReplaceAllString(s, "-")
	s = strings.Trim(s, "-")
	return strings.ToLower(s)
}
|
||||
|
||||
// ExportModel starts an async model export from a checkpoint and returns the intended model name immediately.
|
||||
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
|
||||
s.mu.Lock()
|
||||
job, ok := s.jobs[jobID]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if userID != "" && job.UserID != userID {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("job not found: %s", jobID)
|
||||
}
|
||||
if job.ExportStatus == "exporting" {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("export already in progress for job %s", jobID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Compute model name
|
||||
modelName := req.Name
|
||||
if modelName == "" {
|
||||
base := sanitizeModelName(job.Model)
|
||||
if base == "" {
|
||||
base = "model"
|
||||
}
|
||||
shortID := jobID
|
||||
if len(shortID) > 8 {
|
||||
shortID = shortID[:8]
|
||||
}
|
||||
modelName = base + "-ft-" + shortID
|
||||
}
|
||||
|
||||
// Compute output path in models directory
|
||||
modelsPath := s.appConfig.SystemState.Model.ModelsPath
|
||||
outputPath := filepath.Join(modelsPath, modelName)
|
||||
|
||||
// Check for name collision (synchronous — fast validation)
|
||||
configPath := filepath.Join(modelsPath, modelName+".yaml")
|
||||
if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
|
||||
return "", fmt.Errorf("invalid model name: %w", err)
|
||||
}
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
|
||||
}
|
||||
|
||||
// Create output directory
|
||||
if err := os.MkdirAll(outputPath, 0750); err != nil {
|
||||
return "", fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// Set export status to "exporting" and persist
|
||||
s.mu.Lock()
|
||||
job.ExportStatus = "exporting"
|
||||
job.ExportMessage = ""
|
||||
job.ExportModelName = ""
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Launch the export in a background goroutine
|
||||
go func() {
|
||||
backendModel, err := s.modelLoader.Load(
|
||||
model.WithBackendString(job.Backend),
|
||||
model.WithModel(job.Backend),
|
||||
model.WithModelID(job.Backend+"-finetune"),
|
||||
)
|
||||
if err != nil {
|
||||
s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Merge job's extra_options (contains hf_token from training) with request's
|
||||
mergedOpts := make(map[string]string)
|
||||
for k, v := range job.ExtraOptions {
|
||||
mergedOpts[k] = v
|
||||
}
|
||||
for k, v := range req.ExtraOptions {
|
||||
mergedOpts[k] = v // request overrides job
|
||||
}
|
||||
|
||||
grpcReq := &pb.ExportModelRequest{
|
||||
CheckpointPath: req.CheckpointPath,
|
||||
OutputPath: outputPath,
|
||||
ExportFormat: req.ExportFormat,
|
||||
QuantizationMethod: req.QuantizationMethod,
|
||||
Model: req.Model,
|
||||
ExtraOptions: mergedOpts,
|
||||
}
|
||||
|
||||
result, err := backendModel.ExportModel(context.Background(), grpcReq)
|
||||
if err != nil {
|
||||
s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
|
||||
return
|
||||
}
|
||||
if !result.Success {
|
||||
s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-import: detect format and generate config
|
||||
cfg, err := importers.ImportLocalPath(outputPath, modelName)
|
||||
if err != nil {
|
||||
s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Name = modelName
|
||||
|
||||
// If base model not detected from files, use the job's model field
|
||||
if cfg.Model == "" && job.Model != "" {
|
||||
cfg.Model = job.Model
|
||||
}
|
||||
|
||||
// Write YAML config
|
||||
yamlData, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Reload configs so the model is immediately available
|
||||
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
xlog.Warn("Failed to reload configs after export", "error", err)
|
||||
}
|
||||
if err := s.configLoader.Preload(modelsPath); err != nil {
|
||||
xlog.Warn("Failed to preload after export", "error", err)
|
||||
}
|
||||
|
||||
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
|
||||
|
||||
s.mu.Lock()
|
||||
job.ExportStatus = "completed"
|
||||
job.ExportModelName = modelName
|
||||
job.ExportMessage = ""
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
// setExportFailed sets the export status to failed with a message.
|
||||
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
|
||||
xlog.Error("Export failed", "job_id", job.ID, "error", message)
|
||||
s.mu.Lock()
|
||||
job.ExportStatus = "failed"
|
||||
job.ExportMessage = message
|
||||
s.saveJobState(job)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// UploadDataset handles dataset file upload and returns the local path.
|
||||
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
|
||||
uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
|
||||
if err := os.MkdirAll(uploadDir, 0750); err != nil {
|
||||
return "", fmt.Errorf("failed to create dataset directory: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename)
|
||||
if err := os.WriteFile(filePath, data, 0640); err != nil {
|
||||
return "", fmt.Errorf("failed to write dataset: %w", err)
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// MarshalProgressEvent converts a progress event to JSON for SSE.
|
||||
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
185
docs/content/features/fine-tuning.md
Normal file
185
docs/content/features/fine-tuning.md
Normal file
@@ -0,0 +1,185 @@
|
||||
+++
|
||||
disableToc = false
|
||||
title = "Fine-Tuning"
|
||||
weight = 18
|
||||
url = '/features/fine-tuning/'
|
||||
+++
|
||||
|
||||
LocalAI supports fine-tuning LLMs directly through the API and Web UI. Fine-tuning is powered by pluggable backends that implement a generic gRPC interface, allowing support for different training frameworks and model types.
|
||||
|
||||
## Supported Backends
|
||||
|
||||
| Backend | Domain | GPU Required | Training Methods | Adapter Types |
|
||||
|---------|--------|-------------|-----------------|---------------|
|
||||
| **unsloth** | LLM fine-tuning | Yes (CUDA) | SFT, GRPO | LoRA/QLoRA |
|
||||
| **trl** | LLM fine-tuning | No (CPU or GPU) | SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO | LoRA, Full |
|
||||
|
||||
## Enabling Fine-Tuning
|
||||
|
||||
Fine-tuning is disabled by default. Enable it with:
|
||||
|
||||
```bash
|
||||
LOCALAI_ENABLE_FINETUNING=true local-ai
|
||||
```
|
||||
|
||||
When authentication is enabled, fine-tuning is a per-user feature (default OFF). Admins can enable it for specific users via the user management API.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start a fine-tuning job
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/fine-tuning/jobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "unsloth/tinyllama-bnb-4bit",
|
||||
"backend": "unsloth",
|
||||
"training_method": "sft",
|
||||
"training_type": "lora",
|
||||
"dataset_source": "yahma/alpaca-cleaned",
|
||||
"num_epochs": 1,
|
||||
"batch_size": 2,
|
||||
"learning_rate": 0.0002,
|
||||
"adapter_rank": 16,
|
||||
"adapter_alpha": 16,
|
||||
"extra_options": {
|
||||
"max_seq_length": "2048",
|
||||
"load_in_4bit": "true"
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. Monitor progress (SSE stream)
|
||||
|
||||
```bash
|
||||
curl -N http://localhost:8080/api/fine-tuning/jobs/{job_id}/progress
|
||||
```
|
||||
|
||||
### 3. List checkpoints
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/api/fine-tuning/jobs/{job_id}/checkpoints
|
||||
```
|
||||
|
||||
### 4. Export model
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/fine-tuning/jobs/{job_id}/export \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"export_format": "gguf",
|
||||
"quantization_method": "q4_k_m",
|
||||
"name": "my-finetuned-model"
|
||||
}'
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `POST` | `/api/fine-tuning/jobs` | Start a fine-tuning job |
|
||||
| `GET` | `/api/fine-tuning/jobs` | List all jobs |
|
||||
| `GET` | `/api/fine-tuning/jobs/:id` | Get job details |
|
||||
| `DELETE` | `/api/fine-tuning/jobs/:id` | Stop a running job |
|
||||
| `GET` | `/api/fine-tuning/jobs/:id/progress` | SSE progress stream |
|
||||
| `GET` | `/api/fine-tuning/jobs/:id/checkpoints` | List checkpoints |
|
||||
| `POST` | `/api/fine-tuning/jobs/:id/export` | Export model |
|
||||
| `POST` | `/api/fine-tuning/datasets` | Upload dataset file |
|
||||
|
||||
### Job Request Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `model` | string | HuggingFace model ID or local path (required) |
|
||||
| `backend` | string | Backend name: `unsloth` or `trl` (default: `trl`) |
|
||||
| `training_method` | string | `sft`, `dpo`, `grpo`, `rloo`, `reward`, `kto`, `orpo` |
|
||||
| `training_type` | string | `lora` or `full` |
|
||||
| `dataset_source` | string | HuggingFace dataset ID or local file path (required) |
|
||||
| `adapter_rank` | int | LoRA rank (default: 16) |
|
||||
| `adapter_alpha` | int | LoRA alpha (default: 16) |
|
||||
| `num_epochs` | int | Number of training epochs (default: 3) |
|
||||
| `batch_size` | int | Per-device batch size (default: 2) |
|
||||
| `learning_rate` | float | Learning rate (default: 2e-4) |
|
||||
| `gradient_accumulation_steps` | int | Gradient accumulation (default: 4) |
|
||||
| `warmup_steps` | int | Warmup steps (default: 5) |
|
||||
| `optimizer` | string | `adamw_torch`, `adamw_8bit`, `sgd`, `adafactor`, `prodigy` |
|
||||
| `extra_options` | map | Backend-specific options (see below) |
|
||||
|
||||
### Backend-Specific Options (`extra_options`)
|
||||
|
||||
#### Unsloth
|
||||
|
||||
| Key | Description | Default |
|
||||
|-----|-------------|---------|
|
||||
| `max_seq_length` | Maximum sequence length | `2048` |
|
||||
| `load_in_4bit` | Load model in 4-bit quantization | `true` |
|
||||
| `packing` | Enable sequence packing | `false` |
|
||||
| `use_rslora` | Use Rank-Stabilized LoRA | `false` |
|
||||
|
||||
#### TRL
|
||||
|
||||
| Key | Description | Default |
|
||||
|-----|-------------|---------|
|
||||
| `max_seq_length` | Maximum sequence length | `512` |
|
||||
| `packing` | Enable sequence packing | `false` |
|
||||
| `trust_remote_code` | Trust remote code in model | `false` |
|
||||
| `load_in_4bit` | Enable 4-bit quantization (GPU only) | `false` |
|
||||
|
||||
#### DPO-specific (training_method=dpo)
|
||||
|
||||
| Key | Description | Default |
|
||||
|-----|-------------|---------|
|
||||
| `beta` | KL penalty coefficient | `0.1` |
|
||||
| `loss_type` | Loss type: `sigmoid`, `hinge`, `ipo` | `sigmoid` |
|
||||
| `max_length` | Maximum sequence length | `512` |
|
||||
|
||||
#### GRPO-specific (training_method=grpo)
|
||||
|
||||
| Key | Description | Default |
|
||||
|-----|-------------|---------|
|
||||
| `num_generations` | Number of generations per prompt | `4` |
|
||||
| `max_completion_length` | Max completion token length | `256` |
|
||||
|
||||
### Export Formats
|
||||
|
||||
| Format | Description | Notes |
|
||||
|--------|-------------|-------|
|
||||
| `lora` | LoRA adapter files | Smallest, requires base model |
|
||||
| `merged_16bit` | Full model in 16-bit | Large but standalone |
|
||||
| `merged_4bit` | Full model in 4-bit | Smaller, standalone |
|
||||
| `gguf` | GGUF format | For llama.cpp, requires `quantization_method` |
|
||||
|
||||
### GGUF Quantization Methods
|
||||
|
||||
`q4_k_m`, `q5_k_m`, `q8_0`, `f16`, `q4_0`, `q5_0`
|
||||
|
||||
## Web UI
|
||||
|
||||
When fine-tuning is enabled, a "Fine-Tune" page appears in the sidebar under the Agents section. The UI provides:
|
||||
|
||||
1. **Job Configuration** — Select backend, model, training method, adapter type, and hyperparameters
|
||||
2. **Dataset Upload** — Upload local datasets or reference HuggingFace datasets
|
||||
3. **Training Monitor** — Real-time loss chart, progress bar, metrics display
|
||||
4. **Export** — Export trained models in various formats
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
Datasets should follow standard HuggingFace formats:
|
||||
|
||||
- **SFT**: Alpaca format (`instruction`, `input`, `output` fields) or ChatML/ShareGPT
|
||||
- **DPO**: Preference pairs (`prompt`, `chosen`, `rejected` fields)
|
||||
- **GRPO**: Prompts with reward signals
|
||||
|
||||
Supported file formats: `.json`, `.jsonl`, `.csv`
|
||||
|
||||
## Architecture
|
||||
|
||||
Fine-tuning uses the same gRPC backend architecture as inference:
|
||||
|
||||
1. **Proto layer**: `FineTuneRequest`, `FineTuneProgress` (streaming), `StopFineTune`, `ListCheckpoints`, `ExportModel`
|
||||
2. **Python backends**: Each backend implements the gRPC interface with its specific training framework
|
||||
3. **Go service**: Manages job lifecycle, routes API requests to backends
|
||||
4. **REST API**: HTTP endpoints with SSE progress streaming
|
||||
5. **React UI**: Configuration form, real-time training monitor, export panel
|
||||
Reference in New Issue
Block a user