From 7050c9f69dd243bb5ccaeb1d89a6de9fc13c4740 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 14 Aug 2025 23:48:09 +0200 Subject: [PATCH] feat(webui): add import/edit model page (#6050) * feat(webui): add import/edit model page Signed-off-by: Ettore Di Giacinto * Convert to a YAML editor Signed-off-by: Ettore Di Giacinto * Pass by the baseurl Signed-off-by: Ettore Di Giacinto * Fixups Signed-off-by: Ettore Di Giacinto * Add tests Signed-off-by: Ettore Di Giacinto * Simplify Signed-off-by: Ettore Di Giacinto * Improve visibility of the yaml editor Signed-off-by: Ettore Di Giacinto * Add test file Signed-off-by: Ettore Di Giacinto * Make reset work Signed-off-by: Ettore Di Giacinto * Emit error only if we can't delete the model yaml file Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/config/backend_config.go | 237 ++-- core/gallery/models.go | 18 +- core/http/endpoints/localai/edit_model.go | 228 ++++ .../http/endpoints/localai/edit_model_test.go | 72 ++ core/http/endpoints/localai/import_model.go | 148 +++ .../endpoints/localai/localai_suite_test.go | 13 + core/http/endpoints/localai/types.go | 11 + core/http/routes/localai.go | 17 + core/http/views/index.html | 18 +- core/http/views/model-editor.html | 1022 +++++++++++++++++ 10 files changed, 1655 insertions(+), 129 deletions(-) create mode 100644 core/http/endpoints/localai/edit_model.go create mode 100644 core/http/endpoints/localai/edit_model_test.go create mode 100644 core/http/endpoints/localai/import_model.go create mode 100644 core/http/endpoints/localai/localai_suite_test.go create mode 100644 core/http/endpoints/localai/types.go create mode 100644 core/http/views/model-editor.html diff --git a/core/config/backend_config.go b/core/config/backend_config.go index e39e828fa..123462fc6 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -19,67 +19,68 @@ const ( type TTSConfig struct { // Voice wav path or id - Voice string `yaml:"voice"` + Voice string `yaml:"voice" json:"voice"` - AudioPath string `yaml:"audio_path"` + AudioPath string `yaml:"audio_path" json:"audio_path"` } +// ModelConfig represents a model configuration type ModelConfig struct { - schema.PredictionOptions `yaml:"parameters"` - Name string `yaml:"name"` + schema.PredictionOptions `yaml:"parameters" json:"parameters"` + Name string `yaml:"name" json:"name"` - F16 *bool `yaml:"f16"` - Threads *int `yaml:"threads"` - Debug *bool `yaml:"debug"` - Roles map[string]string `yaml:"roles"` - Embeddings *bool `yaml:"embeddings"` - Backend string `yaml:"backend"` - TemplateConfig TemplateConfig `yaml:"template"` - KnownUsecaseStrings []string `yaml:"known_usecases"` - KnownUsecases *ModelConfigUsecases `yaml:"-"` - Pipeline Pipeline `yaml:"pipeline"` + F16 *bool `yaml:"f16" json:"f16"` + Threads *int `yaml:"threads" json:"threads"` + Debug *bool `yaml:"debug" json:"debug"` + Roles map[string]string `yaml:"roles" json:"roles"` + Embeddings *bool `yaml:"embeddings" json:"embeddings"` + Backend string `yaml:"backend" json:"backend"` + TemplateConfig TemplateConfig `yaml:"template" json:"template"` + KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"` + KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"` + Pipeline Pipeline `yaml:"pipeline" json:"pipeline"` - PromptStrings, InputStrings []string `yaml:"-"` - InputToken [][]int `yaml:"-"` - functionCallString, functionCallNameString string `yaml:"-"` - ResponseFormat string `yaml:"-"` - ResponseFormatMap map[string]interface{} `yaml:"-"` + PromptStrings, InputStrings []string `yaml:"-" json:"-"` + InputToken [][]int `yaml:"-" json:"-"` + functionCallString, functionCallNameString string `yaml:"-" json:"-"` + ResponseFormat string `yaml:"-" json:"-"` + ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"` - FunctionsConfig functions.FunctionsConfig `yaml:"function"` + FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"` - FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early. + FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early. // LLM configs (GPT4ALL, Llama.cpp, ...) - LLMConfig `yaml:",inline"` + LLMConfig `yaml:",inline" json:",inline"` // Diffusers - Diffusers Diffusers `yaml:"diffusers"` - Step int `yaml:"step"` + Diffusers Diffusers `yaml:"diffusers" json:"diffusers"` + Step int `yaml:"step" json:"step"` // GRPC Options - GRPC GRPC `yaml:"grpc"` + GRPC GRPC `yaml:"grpc" json:"grpc"` // TTS specifics - TTSConfig `yaml:"tts"` + TTSConfig `yaml:"tts" json:"tts"` // CUDA // Explicitly enable CUDA or not (some backends might need it) - CUDA bool `yaml:"cuda"` + CUDA bool `yaml:"cuda" json:"cuda"` - DownloadFiles []File `yaml:"download_files"` + DownloadFiles []File `yaml:"download_files" json:"download_files"` - Description string `yaml:"description"` - Usage string `yaml:"usage"` + Description string `yaml:"description" json:"description"` + Usage string `yaml:"usage" json:"usage"` - Options []string `yaml:"options"` - Overrides []string `yaml:"overrides"` + Options []string `yaml:"options" json:"options"` + Overrides []string `yaml:"overrides" json:"overrides"` } // Pipeline defines other models to use for audio-to-audio type Pipeline struct { - TTS string `yaml:"tts"` - LLM string `yaml:"llm"` - Transcription string `yaml:"transcription"` - VAD string `yaml:"vad"` + TTS string `yaml:"tts" json:"tts"` + LLM string `yaml:"llm" json:"llm"` + Transcription string `yaml:"transcription" json:"transcription"` + VAD string `yaml:"vad" json:"vad"` } type File struct { @@ -91,130 +92,132 @@ type File struct { type FeatureFlag map[string]*bool func (ff FeatureFlag) Enabled(s string) bool { - v, exist := ff[s] - return exist && v != nil && *v + if v, exists := ff[s]; exists && v != nil { + return *v + } + return false } type GRPC struct { - Attempts int `yaml:"attempts"` - AttemptsSleepTime int `yaml:"attempts_sleep_time"` + Attempts int `yaml:"attempts" json:"attempts"` + AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"` } type Diffusers struct { - CUDA bool `yaml:"cuda"` - PipelineType string `yaml:"pipeline_type"` - SchedulerType string `yaml:"scheduler_type"` - EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify - IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser - ClipSkip int `yaml:"clip_skip"` // Skip every N frames - ClipModel string `yaml:"clip_model"` // Clip model to use - ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model - ControlNet string `yaml:"control_net"` + CUDA bool `yaml:"cuda" json:"cuda"` + PipelineType string `yaml:"pipeline_type" json:"pipeline_type"` + SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"` + EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify + IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser + ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames + ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use + ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model + ControlNet string `yaml:"control_net" json:"control_net"` } // LLMConfig is a struct that holds the configuration that are // generic for most of the LLM backends. type LLMConfig struct { - SystemPrompt string `yaml:"system_prompt"` - TensorSplit string `yaml:"tensor_split"` - MainGPU string `yaml:"main_gpu"` - RMSNormEps float32 `yaml:"rms_norm_eps"` - NGQA int32 `yaml:"ngqa"` - PromptCachePath string `yaml:"prompt_cache_path"` - PromptCacheAll bool `yaml:"prompt_cache_all"` - PromptCacheRO bool `yaml:"prompt_cache_ro"` - MirostatETA *float64 `yaml:"mirostat_eta"` - MirostatTAU *float64 `yaml:"mirostat_tau"` - Mirostat *int `yaml:"mirostat"` - NGPULayers *int `yaml:"gpu_layers"` - MMap *bool `yaml:"mmap"` - MMlock *bool `yaml:"mmlock"` - LowVRAM *bool `yaml:"low_vram"` - Reranking *bool `yaml:"reranking"` - Grammar string `yaml:"grammar"` - StopWords []string `yaml:"stopwords"` - Cutstrings []string `yaml:"cutstrings"` - ExtractRegex []string `yaml:"extract_regex"` - TrimSpace []string `yaml:"trimspace"` - TrimSuffix []string `yaml:"trimsuffix"` + SystemPrompt string `yaml:"system_prompt" json:"system_prompt"` + TensorSplit string `yaml:"tensor_split" json:"tensor_split"` + MainGPU string `yaml:"main_gpu" json:"main_gpu"` + RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"` + NGQA int32 `yaml:"ngqa" json:"ngqa"` + PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"` + PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"` + PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"` + MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"` + MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"` + Mirostat *int `yaml:"mirostat" json:"mirostat"` + NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"` + MMap *bool `yaml:"mmap" json:"mmap"` + MMlock *bool `yaml:"mmlock" json:"mmlock"` + LowVRAM *bool `yaml:"low_vram" json:"low_vram"` + Reranking *bool `yaml:"reranking" json:"reranking"` + Grammar string `yaml:"grammar" json:"grammar"` + StopWords []string `yaml:"stopwords" json:"stopwords"` + Cutstrings []string `yaml:"cutstrings" json:"cutstrings"` + ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"` + TrimSpace []string `yaml:"trimspace" json:"trimspace"` + TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"` - ContextSize *int `yaml:"context_size"` - NUMA bool `yaml:"numa"` - LoraAdapter string `yaml:"lora_adapter"` - LoraBase string `yaml:"lora_base"` - LoraAdapters []string `yaml:"lora_adapters"` - LoraScales []float32 `yaml:"lora_scales"` - LoraScale float32 `yaml:"lora_scale"` - NoMulMatQ bool `yaml:"no_mulmatq"` - DraftModel string `yaml:"draft_model"` - NDraft int32 `yaml:"n_draft"` - Quantization string `yaml:"quantization"` - LoadFormat string `yaml:"load_format"` - GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM - TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM - EnforceEager bool `yaml:"enforce_eager"` // vLLM - SwapSpace int `yaml:"swap_space"` // vLLM - MaxModelLen int `yaml:"max_model_len"` // vLLM - TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM - DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM - DType string `yaml:"dtype"` // vLLM - LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM - MMProj string `yaml:"mmproj"` + ContextSize *int `yaml:"context_size" json:"context_size"` + NUMA bool `yaml:"numa" json:"numa"` + LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"` + LoraBase string `yaml:"lora_base" json:"lora_base"` + LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"` + LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"` + LoraScale float32 `yaml:"lora_scale" json:"lora_scale"` + NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"` + DraftModel string `yaml:"draft_model" json:"draft_model"` + NDraft int32 `yaml:"n_draft" json:"n_draft"` + Quantization string `yaml:"quantization" json:"quantization"` + LoadFormat string `yaml:"load_format" json:"load_format"` + GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM + TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM + EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM + SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM + MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM + TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM + DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM + DType string `yaml:"dtype" json:"dtype"` // vLLM + LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM + MMProj string `yaml:"mmproj" json:"mmproj"` - FlashAttention bool `yaml:"flash_attention"` - NoKVOffloading bool `yaml:"no_kv_offloading"` - CacheTypeK string `yaml:"cache_type_k"` - CacheTypeV string `yaml:"cache_type_v"` + FlashAttention bool `yaml:"flash_attention" json:"flash_attention"` + NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"` + CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"` + CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"` - RopeScaling string `yaml:"rope_scaling"` - ModelType string `yaml:"type"` + RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"` + ModelType string `yaml:"type" json:"type"` - YarnExtFactor float32 `yaml:"yarn_ext_factor"` - YarnAttnFactor float32 `yaml:"yarn_attn_factor"` - YarnBetaFast float32 `yaml:"yarn_beta_fast"` - YarnBetaSlow float32 `yaml:"yarn_beta_slow"` + YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"` + YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"` + YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"` + YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"` - CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale + CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale } // LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM type LimitMMPerPrompt struct { - LimitImagePerPrompt int `yaml:"image"` - LimitVideoPerPrompt int `yaml:"video"` - LimitAudioPerPrompt int `yaml:"audio"` + LimitImagePerPrompt int `yaml:"image" json:"image"` + LimitVideoPerPrompt int `yaml:"video" json:"video"` + LimitAudioPerPrompt int `yaml:"audio" json:"audio"` } // TemplateConfig is a struct that holds the configuration of the templating system type TemplateConfig struct { // Chat is the template used in the chat completion endpoint - Chat string `yaml:"chat"` + Chat string `yaml:"chat" json:"chat"` // ChatMessage is the template used for chat messages - ChatMessage string `yaml:"chat_message"` + ChatMessage string `yaml:"chat_message" json:"chat_message"` // Completion is the template used for completion requests - Completion string `yaml:"completion"` + Completion string `yaml:"completion" json:"completion"` // Edit is the template used for edit completion requests - Edit string `yaml:"edit"` + Edit string `yaml:"edit" json:"edit"` // Functions is the template used when tools are present in the client requests - Functions string `yaml:"function"` + Functions string `yaml:"function" json:"function"` // UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used. // Note: this is mostly consumed for backends such as vllm and transformers // that can use the tokenizers specified in the JSON config files of the models - UseTokenizerTemplate bool `yaml:"use_tokenizer_template"` + UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"` // JoinChatMessagesByCharacter is a string that will be used to join chat messages together. // It defaults to \n - JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"` + JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"` - Multimodal string `yaml:"multimodal"` + Multimodal string `yaml:"multimodal" json:"multimodal"` - JinjaTemplate bool `yaml:"jinja_template"` + JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"` - ReplyPrefix string `yaml:"reply_prefix"` + ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"` } func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { diff --git a/core/gallery/models.go b/core/gallery/models.go index f161e0cf3..57d07be57 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -316,17 +316,12 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { return fmt.Errorf("failed to verify path %s: %w", galleryFile, err) } + var filesToRemove []string + // Delete all the files associated to the model // read the model config galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile) - if err != nil { - log.Error().Err(err).Msgf("failed to read gallery file %s", configFile) - } - - var filesToRemove []string - - // Remove additional files - if galleryconfig != nil { + if err == nil && galleryconfig != nil { for _, f := range galleryconfig.Files { fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename) if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil { @@ -334,6 +329,8 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { } filesToRemove = append(filesToRemove, fullPath) } + } else { + log.Error().Err(err).Msgf("failed to read gallery file %s", configFile) } for _, f := range additionalFiles { @@ -344,7 +341,6 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { filesToRemove = append(filesToRemove, fullPath) } - filesToRemove = append(filesToRemove, configFile) filesToRemove = append(filesToRemove, galleryFile) // skip duplicates @@ -353,11 +349,11 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { // Removing files for _, f := range filesToRemove { if e := os.Remove(f); e != nil { - err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e)) + log.Error().Err(e).Msgf("failed to remove file %s", f) } } - return err + return os.Remove(configFile) } // This is ***NEVER*** going to be perfect or finished. diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go new file mode 100644 index 000000000..c040ee1d6 --- /dev/null +++ b/core/http/endpoints/localai/edit_model.go @@ -0,0 +1,228 @@ +package localai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/config" + httpUtils "github.com/mudler/LocalAI/core/http/utils" + "github.com/mudler/LocalAI/internal" + "github.com/mudler/LocalAI/pkg/utils" + + "gopkg.in/yaml.v3" +) + +// GetEditModelPage renders the edit model page with current configuration +func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { + return func(c *fiber.Ctx) error { + modelName := c.Params("name") + if modelName == "" { + response := ModelResponse{ + Success: false, + Error: "Model name is required", + } + return c.Status(400).JSON(response) + } + + modelConfig, exists := cl.GetModelConfig(modelName) + if !exists { + response := ModelResponse{ + Success: false, + Error: "Model configuration not found", + } + return c.Status(404).JSON(response) + } + + configData, err := yaml.Marshal(modelConfig) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to marshal configuration: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Marshal the config to JSON for the template + configJSON, err := json.Marshal(modelConfig) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to marshal configuration: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Render the edit page with the current configuration + templateData := struct { + Title string + ModelName string + Config *config.ModelConfig + ConfigJSON string + ConfigYAML string + BaseURL string + Version string + }{ + Title: "LocalAI - Edit Model " + modelName, + ModelName: modelName, + Config: &modelConfig, + ConfigJSON: string(configJSON), + ConfigYAML: string(configData), + BaseURL: httpUtils.BaseURL(c), + Version: internal.PrintableVersion(), + } + + return c.Render("views/model-editor", templateData) + } +} + +// EditModelEndpoint handles updating existing model configurations +func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { + return func(c *fiber.Ctx) error { + modelName := c.Params("name") + if modelName == "" { + response := ModelResponse{ + Success: false, + Error: "Model name is required", + } + return c.Status(400).JSON(response) + } + + // Get the raw body + body := c.Body() + if len(body) == 0 { + response := ModelResponse{ + Success: false, + Error: "Request body is empty", + } + return c.Status(400).JSON(response) + } + + // Check content type to determine how to parse + contentType := string(c.Context().Request.Header.ContentType()) + var req config.ModelConfig + var err error + + if strings.Contains(contentType, "application/json") { + // Parse JSON + if err := json.Unmarshal(body, &req); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse JSON: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") { + // Parse YAML + if err := yaml.Unmarshal(body, &req); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse YAML: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else { + // Try to auto-detect format + if strings.TrimSpace(string(body))[0] == '{' { + // Looks like JSON + if err := json.Unmarshal(body, &req); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse JSON: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else { + // Assume YAML + if err := yaml.Unmarshal(body, &req); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse YAML: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } + } + + // Validate required fields + if req.Name == "" { + response := ModelResponse{ + Success: false, + Error: "Name is required", + } + return c.Status(400).JSON(response) + } + + // Load the existing configuration + configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml") + if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Model configuration not trusted: " + err.Error(), + } + return c.Status(404).JSON(response) + } + + // Set defaults + req.SetDefaults() + + // Validate the configuration + if !req.Validate() { + response := ModelResponse{ + Success: false, + Error: "Validation failed", + Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."}, + } + return c.Status(400).JSON(response) + } + + // Create the YAML file + yamlData, err := yaml.Marshal(req) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to marshal configuration: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Write to file + if err := os.WriteFile(configPath, yamlData, 0644); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to write configuration file: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Reload configurations + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to reload configurations: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Preload the model + if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to preload model: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Return success response + response := ModelResponse{ + Success: true, + Message: fmt.Sprintf("Model '%s' updated successfully", modelName), + Filename: configPath, + Config: req, + } + return c.JSON(response) + } +} diff --git a/core/http/endpoints/localai/edit_model_test.go b/core/http/endpoints/localai/edit_model_test.go new file mode 100644 index 000000000..813e91301 --- /dev/null +++ b/core/http/endpoints/localai/edit_model_test.go @@ -0,0 +1,72 @@ +package localai_test + +import ( + "bytes" + "io" + "net/http/httptest" + "os" + "path/filepath" + + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/pkg/system" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Edit Model test", func() { + + var tempDir string + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "localai-test") + Expect(err).ToNot(HaveOccurred()) + }) + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + Context("Edit Model endpoint", func() { + It("should edit a model", func() { + systemState, err := system.GetSystemState( + system.WithModelPath(filepath.Join(tempDir)), + ) + Expect(err).ToNot(HaveOccurred()) + + applicationConfig := config.NewApplicationConfig( + config.WithSystemState(systemState), + ) + //modelLoader := model.NewModelLoader(systemState, true) + modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath) + + // Define Fiber app. + app := fiber.New() + app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig)) + + requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`) + + req := httptest.NewRequest("PUT", "/import-model", requestBody) + resp, err := app.Test(req, 5000) + Expect(err).ToNot(HaveOccurred()) + + body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(ContainSubstring("Model configuration created successfully")) + Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) + + app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig)) + requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`) + + req = httptest.NewRequest("GET", "/edit-model/foo", requestBody) + resp, _ = app.Test(req, 1) + + body, err = io.ReadAll(resp.Body) + defer resp.Body.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(ContainSubstring(`"model":"foo"`)) + Expect(resp.StatusCode).To(Equal(fiber.StatusOK)) + }) + }) +}) diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go new file mode 100644 index 000000000..bdb84334d --- /dev/null +++ b/core/http/endpoints/localai/import_model.go @@ -0,0 +1,148 @@ +package localai + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + + "github.com/gofiber/fiber/v2" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/utils" + "gopkg.in/yaml.v3" +) + +// ImportModelEndpoint handles creating new model configurations +func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler { + return func(c *fiber.Ctx) error { + // Get the raw body + body := c.Body() + if len(body) == 0 { + response := ModelResponse{ + Success: false, + Error: "Request body is empty", + } + return c.Status(400).JSON(response) + } + + // Check content type to determine how to parse + contentType := string(c.Context().Request.Header.ContentType()) + var modelConfig config.ModelConfig + var err error + + if strings.Contains(contentType, "application/json") { + // Parse JSON + if err := json.Unmarshal(body, &modelConfig); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse JSON: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") { + // Parse YAML + if err := yaml.Unmarshal(body, &modelConfig); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse YAML: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else { + // Try to auto-detect format + if strings.TrimSpace(string(body))[0] == '{' { + // Looks like JSON + if err := json.Unmarshal(body, &modelConfig); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse JSON: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } else { + // Assume YAML + if err := yaml.Unmarshal(body, &modelConfig); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to parse YAML: " + err.Error(), + } + return c.Status(400).JSON(response) + } + } + } + + // Validate required fields + if modelConfig.Name == "" { + response := ModelResponse{ + Success: false, + Error: "Name is required", + } + return c.Status(400).JSON(response) + } + + // Set defaults + modelConfig.SetDefaults() + + // Validate the configuration + if !modelConfig.Validate() { + response := ModelResponse{ + Success: false, + Error: "Invalid configuration", + } + return c.Status(400).JSON(response) + } + + // Create the configuration file + configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelConfig.Name+".yaml") + if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Model path not trusted: " + err.Error(), + } + return c.Status(400).JSON(response) + } + + // Marshal to YAML for storage + yamlData, err := yaml.Marshal(&modelConfig) + if err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to marshal configuration: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Write the file + if err := os.WriteFile(configPath, yamlData, 0644); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to write configuration file: " + err.Error(), + } + return c.Status(500).JSON(response) + } + // Reload configurations + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to reload configurations: " + err.Error(), + } + return c.Status(500).JSON(response) + } + + // Preload the model + if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil { + response := ModelResponse{ + Success: false, + Error: "Failed to preload model: " + err.Error(), + } + return c.Status(500).JSON(response) + } + // Return success response + response := ModelResponse{ + Success: true, + Message: "Model configuration created successfully", + Filename: filepath.Base(configPath), + } + return c.JSON(response) + } +} diff --git a/core/http/endpoints/localai/localai_suite_test.go b/core/http/endpoints/localai/localai_suite_test.go new file mode 100644 index 000000000..ea415bf70 --- /dev/null +++ b/core/http/endpoints/localai/localai_suite_test.go @@ -0,0 +1,13 @@ +package localai_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestLocalAIEndpoints(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "LocalAI Endpoints test suite") +} diff --git a/core/http/endpoints/localai/types.go b/core/http/endpoints/localai/types.go new file mode 100644 index 000000000..32a549089 --- /dev/null +++ b/core/http/endpoints/localai/types.go @@ -0,0 +1,11 @@ +package localai + +// ModelResponse represents the common response structure for model operations +type ModelResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Filename string `json:"filename,omitempty"` + Config interface{} `json:"config,omitempty"` + Error string `json:"error,omitempty"` + Details []string `json:"details,omitempty"` +} diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 54b7021f1..9a7c1d868 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -6,6 +6,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/middleware" + httpUtils "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" @@ -23,6 +24,17 @@ func RegisterLocalAIRoutes(router *fiber.App, // LocalAI API endpoints if !appConfig.DisableGalleryEndpoint { + // Import model page + router.Get("/import-model", func(c *fiber.Ctx) error { + return c.Render("views/model-editor", fiber.Map{ + "Title": "LocalAI - Import Model", + "BaseURL": httpUtils.BaseURL(c), + "Version": internal.PrintableVersion(), + }) + }) + + // Edit model page + router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig)) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService) router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint()) router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint()) @@ -42,6 +54,11 @@ func RegisterLocalAIRoutes(router *fiber.App, router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState)) router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint()) router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint()) + // Custom model import endpoint + router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig)) + + // Custom model edit endpoint + router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig)) } router.Post("/v1/detection", diff --git a/core/http/views/index.html b/core/http/views/index.html index a6ab0d8f9..d90c899e3 100644 --- a/core/http/views/index.html +++ b/core/http/views/index.html @@ -30,6 +30,12 @@ Gallery + + + Import Custom Model + + @@ -143,7 +149,11 @@ {{ end }} -
+ + +
+ + Cannot edit (no config) + +
diff --git a/core/http/views/model-editor.html b/core/http/views/model-editor.html new file mode 100644 index 000000000..ba4c4befd --- /dev/null +++ b/core/http/views/model-editor.html @@ -0,0 +1,1022 @@ + + +{{template "views/partials/head" .}} + + +
+ + {{template "views/partials/navbar" .}} + +
+ +
+
+
+

+ {{if .ModelName}}Edit Model: {{.ModelName}}{{else}}Import New Model{{end}} +

+

Configure your model settings using the form or YAML editor

+
+
+ + +
+
+
+ + +
+ + +
+ + +
+
+

+ + Configuration Form +

+
+ + +
+
+
+
+ +
+
+
+ + +
+
+

+ + YAML Editor +

+
+ + +
+
+
+
+
+
+
+
+ + {{template "views/partials/footer" .}} +
+ + + + + + + + + + + + + + + +