feat(webui): add import/edit model page (#6050)

* feat(webui): add import/edit model page

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Convert to a YAML editor

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Pass by the baseurl

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Simplify

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Improve visibility of the yaml editor

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add test file

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Make reset work

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Emit error only if we can't delete the model yaml file

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-08-14 23:48:09 +02:00
committed by GitHub
parent 089efe05fd
commit 7050c9f69d
10 changed files with 1655 additions and 129 deletions

View File

@@ -19,67 +19,68 @@ const (
type TTSConfig struct {
// Voice wav path or id
Voice string `yaml:"voice"`
Voice string `yaml:"voice" json:"voice"`
AudioPath string `yaml:"audio_path"`
AudioPath string `yaml:"audio_path" json:"audio_path"`
}
// ModelConfig represents a model configuration
type ModelConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
schema.PredictionOptions `yaml:"parameters" json:"parameters"`
Name string `yaml:"name" json:"name"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *ModelConfigUsecases `yaml:"-"`
Pipeline Pipeline `yaml:"pipeline"`
F16 *bool `yaml:"f16" json:"f16"`
Threads *int `yaml:"threads" json:"threads"`
Debug *bool `yaml:"debug" json:"debug"`
Roles map[string]string `yaml:"roles" json:"roles"`
Embeddings *bool `yaml:"embeddings" json:"embeddings"`
Backend string `yaml:"backend" json:"backend"`
TemplateConfig TemplateConfig `yaml:"template" json:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
ResponseFormat string `yaml:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline"`
LLMConfig `yaml:",inline" json:",inline"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers"`
Step int `yaml:"step"`
Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
Step int `yaml:"step" json:"step"`
// GRPC Options
GRPC GRPC `yaml:"grpc"`
GRPC GRPC `yaml:"grpc" json:"grpc"`
// TTS specifics
TTSConfig `yaml:"tts"`
TTSConfig `yaml:"tts" json:"tts"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`
CUDA bool `yaml:"cuda" json:"cuda"`
DownloadFiles []File `yaml:"download_files"`
DownloadFiles []File `yaml:"download_files" json:"download_files"`
Description string `yaml:"description"`
Usage string `yaml:"usage"`
Description string `yaml:"description" json:"description"`
Usage string `yaml:"usage" json:"usage"`
Options []string `yaml:"options"`
Overrides []string `yaml:"overrides"`
Options []string `yaml:"options" json:"options"`
Overrides []string `yaml:"overrides" json:"overrides"`
}
// Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
TTS string `yaml:"tts"`
LLM string `yaml:"llm"`
Transcription string `yaml:"transcription"`
VAD string `yaml:"vad"`
TTS string `yaml:"tts" json:"tts"`
LLM string `yaml:"llm" json:"llm"`
Transcription string `yaml:"transcription" json:"transcription"`
VAD string `yaml:"vad" json:"vad"`
}
type File struct {
@@ -91,130 +92,132 @@ type File struct {
type FeatureFlag map[string]*bool
func (ff FeatureFlag) Enabled(s string) bool {
v, exist := ff[s]
return exist && v != nil && *v
if v, exists := ff[s]; exists && v != nil {
return *v
}
return false
}
type GRPC struct {
Attempts int `yaml:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
Attempts int `yaml:"attempts" json:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
}
type Diffusers struct {
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
CUDA bool `yaml:"cuda" json:"cuda"`
PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net" json:"control_net"`
}
// LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends.
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt"`
TensorSplit string `yaml:"tensor_split"`
MainGPU string `yaml:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
MirostatETA *float64 `yaml:"mirostat_eta"`
MirostatTAU *float64 `yaml:"mirostat_tau"`
Mirostat *int `yaml:"mirostat"`
NGPULayers *int `yaml:"gpu_layers"`
MMap *bool `yaml:"mmap"`
MMlock *bool `yaml:"mmlock"`
LowVRAM *bool `yaml:"low_vram"`
Reranking *bool `yaml:"reranking"`
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
ExtractRegex []string `yaml:"extract_regex"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
MainGPU string `yaml:"main_gpu" json:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa" json:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
Mirostat *int `yaml:"mirostat" json:"mirostat"`
NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
MMap *bool `yaml:"mmap" json:"mmap"`
MMlock *bool `yaml:"mmlock" json:"mmlock"`
LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
Reranking *bool `yaml:"reranking" json:"reranking"`
Grammar string `yaml:"grammar" json:"grammar"`
StopWords []string `yaml:"stopwords" json:"stopwords"`
Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
TrimSpace []string `yaml:"trimspace" json:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
DType string `yaml:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj"`
ContextSize *int `yaml:"context_size" json:"context_size"`
NUMA bool `yaml:"numa" json:"numa"`
LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
LoraBase string `yaml:"lora_base" json:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
DraftModel string `yaml:"draft_model" json:"draft_model"`
NDraft int32 `yaml:"n_draft" json:"n_draft"`
Quantization string `yaml:"quantization" json:"quantization"`
LoadFormat string `yaml:"load_format" json:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
DType string `yaml:"dtype" json:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj" json:"mmproj"`
FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v"`
FlashAttention bool `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling"`
ModelType string `yaml:"type"`
RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
ModelType string `yaml:"type" json:"type"`
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
}
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image"`
LimitVideoPerPrompt int `yaml:"video"`
LimitAudioPerPrompt int `yaml:"audio"`
LimitImagePerPrompt int `yaml:"image" json:"image"`
LimitVideoPerPrompt int `yaml:"video" json:"video"`
LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
}
// TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
// Chat is the template used in the chat completion endpoint
Chat string `yaml:"chat"`
Chat string `yaml:"chat" json:"chat"`
// ChatMessage is the template used for chat messages
ChatMessage string `yaml:"chat_message"`
ChatMessage string `yaml:"chat_message" json:"chat_message"`
// Completion is the template used for completion requests
Completion string `yaml:"completion"`
Completion string `yaml:"completion" json:"completion"`
// Edit is the template used for edit completion requests
Edit string `yaml:"edit"`
Edit string `yaml:"edit" json:"edit"`
// Functions is the template used when tools are present in the client requests
Functions string `yaml:"function"`
Functions string `yaml:"function" json:"function"`
// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
// Note: this is mostly consumed for backends such as vllm and transformers
// that can use the tokenizers specified in the JSON config files of the models
UseTokenizerTemplate bool `yaml:"use_tokenizer_template"`
UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
// It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
Multimodal string `yaml:"multimodal"`
Multimodal string `yaml:"multimodal" json:"multimodal"`
JinjaTemplate bool `yaml:"jinja_template"`
JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
ReplyPrefix string `yaml:"reply_prefix"`
ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
}
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {

View File

@@ -316,17 +316,12 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
}
var filesToRemove []string
// Delete all the files associated to the model
// read the model config
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
if err != nil {
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
}
var filesToRemove []string
// Remove additional files
if galleryconfig != nil {
if err == nil && galleryconfig != nil {
for _, f := range galleryconfig.Files {
fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
@@ -334,6 +329,8 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
}
filesToRemove = append(filesToRemove, fullPath)
}
} else {
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
}
for _, f := range additionalFiles {
@@ -344,7 +341,6 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
filesToRemove = append(filesToRemove, fullPath)
}
filesToRemove = append(filesToRemove, configFile)
filesToRemove = append(filesToRemove, galleryFile)
// skip duplicates
@@ -353,11 +349,11 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
// Removing files
for _, f := range filesToRemove {
if e := os.Remove(f); e != nil {
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e))
log.Error().Err(e).Msgf("failed to remove file %s", f)
}
}
return err
return os.Remove(configFile)
}
// This is ***NEVER*** going to be perfect or finished.

View File

@@ -0,0 +1,228 @@
package localai
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/utils"
"gopkg.in/yaml.v3"
)
// GetEditModelPage renders the edit model page with current configuration.
//
// It looks the model up in the config loader by the ":name" route
// parameter, serializes the configuration to both YAML and JSON, and
// renders the "views/model-editor" template with both representations
// (the page uses YAML for the editor and JSON for client-side state).
//
// Error responses use the shared ModelResponse envelope:
//   - 400 when the name parameter is missing
//   - 404 when no configuration exists for the model
//   - 500 when serialization fails
//
// NOTE(review): appConfig is currently unused by this handler; it is kept
// for signature parity with the other model endpoints in this package.
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		modelName := c.Params("name")
		if modelName == "" {
			response := ModelResponse{
				Success: false,
				Error:   "Model name is required",
			}
			return c.Status(400).JSON(response)
		}

		modelConfig, exists := cl.GetModelConfig(modelName)
		if !exists {
			response := ModelResponse{
				Success: false,
				Error:   "Model configuration not found",
			}
			return c.Status(404).JSON(response)
		}

		// YAML form feeds the editor widget on the page.
		configData, err := yaml.Marshal(modelConfig)
		if err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to marshal configuration: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Marshal the config to JSON for the template
		configJSON, err := json.Marshal(modelConfig)
		if err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to marshal configuration: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Render the edit page with the current configuration
		templateData := struct {
			Title      string
			ModelName  string
			Config     *config.ModelConfig
			ConfigJSON string
			ConfigYAML string
			BaseURL    string
			Version    string
		}{
			Title:      "LocalAI - Edit Model " + modelName,
			ModelName:  modelName,
			Config:     &modelConfig,
			ConfigJSON: string(configJSON),
			ConfigYAML: string(configData),
			BaseURL:    httpUtils.BaseURL(c),
			Version:    internal.PrintableVersion(),
		}

		return c.Render("views/model-editor", templateData)
	}
}
// EditModelEndpoint handles updating existing model configurations.
//
// The request body may be JSON or YAML. The format is selected from the
// Content-Type header; when the header is absent or unrecognized, the
// format is sniffed from the first non-whitespace byte of the payload
// ('{' means JSON, anything else is treated as YAML). The parsed
// configuration is validated, written to <models-path>/<name>.yaml, and
// the config loader is reloaded so the change takes effect immediately.
//
// Responses use the shared ModelResponse envelope: 400 for client errors
// (missing name, empty body, parse/validation failure), 404 when the
// target path fails the trusted-root check, 500 for server-side failures
// (marshal, write, reload, preload).
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		modelName := c.Params("name")
		if modelName == "" {
			response := ModelResponse{
				Success: false,
				Error:   "Model name is required",
			}
			return c.Status(400).JSON(response)
		}

		// Get the raw body
		body := c.Body()
		if len(body) == 0 {
			response := ModelResponse{
				Success: false,
				Error:   "Request body is empty",
			}
			return c.Status(400).JSON(response)
		}

		// Check content type to determine how to parse
		contentType := string(c.Context().Request.Header.ContentType())

		var req config.ModelConfig

		if strings.Contains(contentType, "application/json") {
			// Parse JSON
			if err := json.Unmarshal(body, &req); err != nil {
				response := ModelResponse{
					Success: false,
					Error:   "Failed to parse JSON: " + err.Error(),
				}
				return c.Status(400).JSON(response)
			}
		} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
			// Parse YAML
			if err := yaml.Unmarshal(body, &req); err != nil {
				response := ModelResponse{
					Success: false,
					Error:   "Failed to parse YAML: " + err.Error(),
				}
				return c.Status(400).JSON(response)
			}
		} else {
			// Try to auto-detect the format from the payload itself.
			// Guard against a whitespace-only body: indexing the trimmed
			// string without this check panics with index out of range.
			trimmed := strings.TrimSpace(string(body))
			if trimmed == "" {
				response := ModelResponse{
					Success: false,
					Error:   "Request body is empty",
				}
				return c.Status(400).JSON(response)
			}
			if trimmed[0] == '{' {
				// Looks like JSON
				if err := json.Unmarshal(body, &req); err != nil {
					response := ModelResponse{
						Success: false,
						Error:   "Failed to parse JSON: " + err.Error(),
					}
					return c.Status(400).JSON(response)
				}
			} else {
				// Assume YAML
				if err := yaml.Unmarshal(body, &req); err != nil {
					response := ModelResponse{
						Success: false,
						Error:   "Failed to parse YAML: " + err.Error(),
					}
					return c.Status(400).JSON(response)
				}
			}
		}

		// Validate required fields
		if req.Name == "" {
			response := ModelResponse{
				Success: false,
				Error:   "Name is required",
			}
			return c.Status(400).JSON(response)
		}

		// Build the target path from the URL parameter and make sure it
		// stays inside the models directory (path-traversal guard).
		// NOTE(review): the file is named after the URL parameter, while
		// the config keeps req.Name — the two can diverge; confirm this
		// is intended.
		configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml")
		if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Model configuration not trusted: " + err.Error(),
			}
			return c.Status(404).JSON(response)
		}

		// Set defaults
		req.SetDefaults()

		// Validate the configuration
		if !req.Validate() {
			response := ModelResponse{
				Success: false,
				Error:   "Validation failed",
				Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
			}
			return c.Status(400).JSON(response)
		}

		// Serialize the updated configuration back to YAML
		yamlData, err := yaml.Marshal(req)
		if err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to marshal configuration: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Write to file
		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to write configuration file: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Reload configurations so the running instance picks up the edit
		if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to reload configurations: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Preload the model
		if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to preload model: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Return success response
		response := ModelResponse{
			Success:  true,
			Message:  fmt.Sprintf("Model '%s' updated successfully", modelName),
			Filename: configPath,
			Config:   req,
		}
		return c.JSON(response)
	}
}

View File

@@ -0,0 +1,72 @@
package localai_test
import (
"bytes"
"io"
"net/http/httptest"
"os"
"path/filepath"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Ginkgo spec covering the import-model and edit-model HTTP endpoints.
// Each spec gets a fresh temporary models directory so configs written
// by one test cannot leak into another.
var _ = Describe("Edit Model test", func() {
	var tempDir string

	BeforeEach(func() {
		var err error
		// Fresh models directory per spec.
		tempDir, err = os.MkdirTemp("", "localai-test")
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Context("Edit Model endpoint", func() {
		It("should edit a model", func() {
			systemState, err := system.GetSystemState(
				system.WithModelPath(filepath.Join(tempDir)),
			)
			Expect(err).ToNot(HaveOccurred())
			applicationConfig := config.NewApplicationConfig(
				config.WithSystemState(systemState),
			)
			//modelLoader := model.NewModelLoader(systemState, true)
			modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)

			// Define Fiber app.
			app := fiber.New()
			app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))

			// First create a model via the import endpoint so there is
			// something to edit afterwards.
			requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
			req := httptest.NewRequest("PUT", "/import-model", requestBody)
			resp, err := app.Test(req, 5000)
			Expect(err).ToNot(HaveOccurred())
			body, err := io.ReadAll(resp.Body)
			defer resp.Body.Close()
			Expect(err).ToNot(HaveOccurred())
			Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
			Expect(resp.StatusCode).To(Equal(fiber.StatusOK))

			// NOTE(review): the edit endpoint is registered under GET here
			// (the router wires it as POST /models/edit/:name in routes);
			// confirm GET is intentional for this test.
			app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
			requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
			req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
			// NOTE(review): app.Test timeout is 1 ms and the error is
			// discarded — a slow CI machine could nil-panic on resp below;
			// consider a larger timeout and checking the error.
			resp, _ = app.Test(req, 1)
			body, err = io.ReadAll(resp.Body)
			defer resp.Body.Close()
			Expect(err).ToNot(HaveOccurred())
			Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
			Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
		})
	})
})

View File

@@ -0,0 +1,148 @@
package localai
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/utils"
"gopkg.in/yaml.v3"
)
// ImportModelEndpoint handles creating new model configurations.
//
// The request body may be JSON or YAML. The format is selected from the
// Content-Type header; when the header is absent or unrecognized, the
// format is sniffed from the first non-whitespace byte of the payload
// ('{' means JSON, anything else is treated as YAML). The parsed
// configuration is validated, written to <models-path>/<name>.yaml, and
// the config loader is reloaded so the new model is available at once.
//
// Responses use the shared ModelResponse envelope: 400 for client errors
// (empty body, parse/validation failure, untrusted path), 500 for
// server-side failures (marshal, write, reload, preload).
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Get the raw body
		body := c.Body()
		if len(body) == 0 {
			response := ModelResponse{
				Success: false,
				Error:   "Request body is empty",
			}
			return c.Status(400).JSON(response)
		}

		// Check content type to determine how to parse
		contentType := string(c.Context().Request.Header.ContentType())

		var modelConfig config.ModelConfig

		if strings.Contains(contentType, "application/json") {
			// Parse JSON
			if err := json.Unmarshal(body, &modelConfig); err != nil {
				response := ModelResponse{
					Success: false,
					Error:   "Failed to parse JSON: " + err.Error(),
				}
				return c.Status(400).JSON(response)
			}
		} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
			// Parse YAML
			if err := yaml.Unmarshal(body, &modelConfig); err != nil {
				response := ModelResponse{
					Success: false,
					Error:   "Failed to parse YAML: " + err.Error(),
				}
				return c.Status(400).JSON(response)
			}
		} else {
			// Try to auto-detect the format from the payload itself.
			// Guard against a whitespace-only body: indexing the trimmed
			// string without this check panics with index out of range.
			trimmed := strings.TrimSpace(string(body))
			if trimmed == "" {
				response := ModelResponse{
					Success: false,
					Error:   "Request body is empty",
				}
				return c.Status(400).JSON(response)
			}
			if trimmed[0] == '{' {
				// Looks like JSON
				if err := json.Unmarshal(body, &modelConfig); err != nil {
					response := ModelResponse{
						Success: false,
						Error:   "Failed to parse JSON: " + err.Error(),
					}
					return c.Status(400).JSON(response)
				}
			} else {
				// Assume YAML
				if err := yaml.Unmarshal(body, &modelConfig); err != nil {
					response := ModelResponse{
						Success: false,
						Error:   "Failed to parse YAML: " + err.Error(),
					}
					return c.Status(400).JSON(response)
				}
			}
		}

		// Validate required fields
		if modelConfig.Name == "" {
			response := ModelResponse{
				Success: false,
				Error:   "Name is required",
			}
			return c.Status(400).JSON(response)
		}

		// Set defaults
		modelConfig.SetDefaults()

		// Validate the configuration
		if !modelConfig.Validate() {
			response := ModelResponse{
				Success: false,
				Error:   "Invalid configuration",
			}
			return c.Status(400).JSON(response)
		}

		// Build the target path and make sure it stays inside the models
		// directory (path-traversal guard on the user-supplied name).
		configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelConfig.Name+".yaml")
		if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Model path not trusted: " + err.Error(),
			}
			return c.Status(400).JSON(response)
		}

		// Marshal to YAML for storage
		yamlData, err := yaml.Marshal(&modelConfig)
		if err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to marshal configuration: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Write the file
		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to write configuration file: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Reload configurations so the new model is registered
		if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to reload configurations: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Preload the model
		if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
			response := ModelResponse{
				Success: false,
				Error:   "Failed to preload model: " + err.Error(),
			}
			return c.Status(500).JSON(response)
		}

		// Return success response
		response := ModelResponse{
			Success:  true,
			Message:  "Model configuration created successfully",
			Filename: filepath.Base(configPath),
		}
		return c.JSON(response)
	}
}

View File

@@ -0,0 +1,13 @@
package localai_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestLocalAIEndpoints is the `go test` entry point that hands control to
// the Ginkgo runner; all specs in this package (e.g. the import/edit model
// Describe blocks) run under this single suite.
func TestLocalAIEndpoints(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "LocalAI Endpoints test suite")
}

View File

@@ -0,0 +1,11 @@
package localai
// ModelResponse represents the common response structure for model
// operations (import, edit, render-page errors). Success is always set;
// the remaining fields are populated depending on the outcome.
type ModelResponse struct {
	Success  bool        `json:"success"`            // true on success, false on any error
	Message  string      `json:"message"`            // human-readable success message
	Filename string      `json:"filename,omitempty"` // config file written (path or base name, per endpoint)
	Config   interface{} `json:"config,omitempty"`   // the applied model configuration, when returned
	Error    string      `json:"error,omitempty"`    // short error description on failure
	Details  []string    `json:"details,omitempty"`  // optional extra diagnostics for validation failures
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
@@ -23,6 +24,17 @@ func RegisterLocalAIRoutes(router *fiber.App,
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
// Import model page
router.Get("/import-model", func(c *fiber.Ctx) error {
return c.Render("views/model-editor", fiber.Map{
"Title": "LocalAI - Import Model",
"BaseURL": httpUtils.BaseURL(c),
"Version": internal.PrintableVersion(),
})
})
// Edit model page
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
@@ -42,6 +54,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
// Custom model import endpoint
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
// Custom model edit endpoint
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
}
router.Post("/v1/detection",

View File

@@ -30,6 +30,12 @@
<span>Gallery</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
<a href="/import-model"
class="group flex items-center bg-green-600 hover:bg-green-700 text-white py-2 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-lg">
<i class="fas fa-plus mr-2"></i>
<span>Import Custom Model</span>
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
</a>
</div>
</div>
</div>
@@ -143,7 +149,11 @@
{{ end }}
</div>
<div class="mt-4 flex justify-end">
<div class="mt-4 flex justify-end gap-2">
<a href="/models/edit/{{.Name}}"
class="inline-flex items-center text-xs font-medium text-indigo-400 hover:text-indigo-300 hover:bg-indigo-900/20 rounded-md px-2 py-1 transition-colors duration-200">
<i class="fas fa-edit mr-1.5"></i>Edit
</a>
<button
class="inline-flex items-center text-xs font-medium text-red-400 hover:text-red-300 hover:bg-red-900/20 rounded-md px-2 py-1 transition-colors duration-200"
data-twe-ripple-init=""
@@ -176,6 +186,12 @@
No Configuration
</span>
</div>
<div class="mt-4 flex justify-end">
<span class="inline-flex items-center text-xs font-medium text-gray-400 px-2 py-1">
<i class="fas fa-info-circle mr-1.5"></i>Cannot edit (no config)
</span>
</div>
</div>
</div>
</div>

View File

File diff suppressed because it is too large Load Diff