mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-04 23:14:41 -04:00
feat(webui): add import/edit model page (#6050)
* feat(webui): add import/edit model page Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Convert to a YAML editor Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Pass by the baseurl Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Simplify Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Improve visibility of the yaml editor Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add test file Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Make reset work Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Emit error only if we can't delete the model yaml file Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
089efe05fd
commit
7050c9f69d
@@ -19,67 +19,68 @@ const (
|
||||
type TTSConfig struct {
|
||||
|
||||
// Voice wav path or id
|
||||
Voice string `yaml:"voice"`
|
||||
Voice string `yaml:"voice" json:"voice"`
|
||||
|
||||
AudioPath string `yaml:"audio_path"`
|
||||
AudioPath string `yaml:"audio_path" json:"audio_path"`
|
||||
}
|
||||
|
||||
// ModelConfig represents a model configuration
|
||||
type ModelConfig struct {
|
||||
schema.PredictionOptions `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
schema.PredictionOptions `yaml:"parameters" json:"parameters"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
F16 *bool `yaml:"f16"`
|
||||
Threads *int `yaml:"threads"`
|
||||
Debug *bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases"`
|
||||
KnownUsecases *ModelConfigUsecases `yaml:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline"`
|
||||
F16 *bool `yaml:"f16" json:"f16"`
|
||||
Threads *int `yaml:"threads" json:"threads"`
|
||||
Debug *bool `yaml:"debug" json:"debug"`
|
||||
Roles map[string]string `yaml:"roles" json:"roles"`
|
||||
Embeddings *bool `yaml:"embeddings" json:"embeddings"`
|
||||
Backend string `yaml:"backend" json:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template" json:"template"`
|
||||
KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
|
||||
KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
|
||||
Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-"`
|
||||
ResponseFormat string `yaml:"-"`
|
||||
ResponseFormatMap map[string]interface{} `yaml:"-"`
|
||||
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
|
||||
InputToken [][]int `yaml:"-" json:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-" json:"-"`
|
||||
ResponseFormat string `yaml:"-" json:"-"`
|
||||
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
|
||||
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
|
||||
FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
|
||||
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline"`
|
||||
LLMConfig `yaml:",inline" json:",inline"`
|
||||
|
||||
// Diffusers
|
||||
Diffusers Diffusers `yaml:"diffusers"`
|
||||
Step int `yaml:"step"`
|
||||
Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
|
||||
Step int `yaml:"step" json:"step"`
|
||||
|
||||
// GRPC Options
|
||||
GRPC GRPC `yaml:"grpc"`
|
||||
GRPC GRPC `yaml:"grpc" json:"grpc"`
|
||||
|
||||
// TTS specifics
|
||||
TTSConfig `yaml:"tts"`
|
||||
TTSConfig `yaml:"tts" json:"tts"`
|
||||
|
||||
// CUDA
|
||||
// Explicitly enable CUDA or not (some backends might need it)
|
||||
CUDA bool `yaml:"cuda"`
|
||||
CUDA bool `yaml:"cuda" json:"cuda"`
|
||||
|
||||
DownloadFiles []File `yaml:"download_files"`
|
||||
DownloadFiles []File `yaml:"download_files" json:"download_files"`
|
||||
|
||||
Description string `yaml:"description"`
|
||||
Usage string `yaml:"usage"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Usage string `yaml:"usage" json:"usage"`
|
||||
|
||||
Options []string `yaml:"options"`
|
||||
Overrides []string `yaml:"overrides"`
|
||||
Options []string `yaml:"options" json:"options"`
|
||||
Overrides []string `yaml:"overrides" json:"overrides"`
|
||||
}
|
||||
|
||||
// Pipeline defines other models to use for audio-to-audio
|
||||
type Pipeline struct {
|
||||
TTS string `yaml:"tts"`
|
||||
LLM string `yaml:"llm"`
|
||||
Transcription string `yaml:"transcription"`
|
||||
VAD string `yaml:"vad"`
|
||||
TTS string `yaml:"tts" json:"tts"`
|
||||
LLM string `yaml:"llm" json:"llm"`
|
||||
Transcription string `yaml:"transcription" json:"transcription"`
|
||||
VAD string `yaml:"vad" json:"vad"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
@@ -91,130 +92,132 @@ type File struct {
|
||||
type FeatureFlag map[string]*bool
|
||||
|
||||
func (ff FeatureFlag) Enabled(s string) bool {
|
||||
v, exist := ff[s]
|
||||
return exist && v != nil && *v
|
||||
if v, exists := ff[s]; exists && v != nil {
|
||||
return *v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type GRPC struct {
|
||||
Attempts int `yaml:"attempts"`
|
||||
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
|
||||
Attempts int `yaml:"attempts" json:"attempts"`
|
||||
AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
|
||||
}
|
||||
|
||||
type Diffusers struct {
|
||||
CUDA bool `yaml:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
CUDA bool `yaml:"cuda" json:"cuda"`
|
||||
PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
|
||||
EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
|
||||
IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
|
||||
ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net" json:"control_net"`
|
||||
}
|
||||
|
||||
// LLMConfig is a struct that holds the configuration that are
|
||||
// generic for most of the LLM backends.
|
||||
type LLMConfig struct {
|
||||
SystemPrompt string `yaml:"system_prompt"`
|
||||
TensorSplit string `yaml:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu"`
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps"`
|
||||
NGQA int32 `yaml:"ngqa"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||
MirostatETA *float64 `yaml:"mirostat_eta"`
|
||||
MirostatTAU *float64 `yaml:"mirostat_tau"`
|
||||
Mirostat *int `yaml:"mirostat"`
|
||||
NGPULayers *int `yaml:"gpu_layers"`
|
||||
MMap *bool `yaml:"mmap"`
|
||||
MMlock *bool `yaml:"mmlock"`
|
||||
LowVRAM *bool `yaml:"low_vram"`
|
||||
Reranking *bool `yaml:"reranking"`
|
||||
Grammar string `yaml:"grammar"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
ExtractRegex []string `yaml:"extract_regex"`
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
TrimSuffix []string `yaml:"trimsuffix"`
|
||||
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
|
||||
TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu" json:"main_gpu"`
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
|
||||
NGQA int32 `yaml:"ngqa" json:"ngqa"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
|
||||
MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
|
||||
MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
|
||||
Mirostat *int `yaml:"mirostat" json:"mirostat"`
|
||||
NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
|
||||
MMap *bool `yaml:"mmap" json:"mmap"`
|
||||
MMlock *bool `yaml:"mmlock" json:"mmlock"`
|
||||
LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
|
||||
Reranking *bool `yaml:"reranking" json:"reranking"`
|
||||
Grammar string `yaml:"grammar" json:"grammar"`
|
||||
StopWords []string `yaml:"stopwords" json:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
|
||||
ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
|
||||
TrimSpace []string `yaml:"trimspace" json:"trimspace"`
|
||||
TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
|
||||
|
||||
ContextSize *int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraAdapters []string `yaml:"lora_adapters"`
|
||||
LoraScales []float32 `yaml:"lora_scales"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
LoadFormat string `yaml:"load_format"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
|
||||
DType string `yaml:"dtype"` // vLLM
|
||||
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
ContextSize *int `yaml:"context_size" json:"context_size"`
|
||||
NUMA bool `yaml:"numa" json:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base" json:"lora_base"`
|
||||
LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
|
||||
LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
|
||||
LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model" json:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft" json:"n_draft"`
|
||||
Quantization string `yaml:"quantization" json:"quantization"`
|
||||
LoadFormat string `yaml:"load_format" json:"load_format"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
|
||||
DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
|
||||
DType string `yaml:"dtype" json:"dtype"` // vLLM
|
||||
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
|
||||
MMProj string `yaml:"mmproj" json:"mmproj"`
|
||||
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
CacheTypeK string `yaml:"cache_type_k"`
|
||||
CacheTypeV string `yaml:"cache_type_v"`
|
||||
FlashAttention bool `yaml:"flash_attention" json:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
|
||||
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
|
||||
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
|
||||
|
||||
RopeScaling string `yaml:"rope_scaling"`
|
||||
ModelType string `yaml:"type"`
|
||||
RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
|
||||
ModelType string `yaml:"type" json:"type"`
|
||||
|
||||
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
|
||||
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
|
||||
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
|
||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
|
||||
YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
|
||||
YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
|
||||
YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
|
||||
YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
|
||||
|
||||
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
|
||||
}
|
||||
|
||||
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
|
||||
type LimitMMPerPrompt struct {
|
||||
LimitImagePerPrompt int `yaml:"image"`
|
||||
LimitVideoPerPrompt int `yaml:"video"`
|
||||
LimitAudioPerPrompt int `yaml:"audio"`
|
||||
LimitImagePerPrompt int `yaml:"image" json:"image"`
|
||||
LimitVideoPerPrompt int `yaml:"video" json:"video"`
|
||||
LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
|
||||
}
|
||||
|
||||
// TemplateConfig is a struct that holds the configuration of the templating system
|
||||
type TemplateConfig struct {
|
||||
// Chat is the template used in the chat completion endpoint
|
||||
Chat string `yaml:"chat"`
|
||||
Chat string `yaml:"chat" json:"chat"`
|
||||
|
||||
// ChatMessage is the template used for chat messages
|
||||
ChatMessage string `yaml:"chat_message"`
|
||||
ChatMessage string `yaml:"chat_message" json:"chat_message"`
|
||||
|
||||
// Completion is the template used for completion requests
|
||||
Completion string `yaml:"completion"`
|
||||
Completion string `yaml:"completion" json:"completion"`
|
||||
|
||||
// Edit is the template used for edit completion requests
|
||||
Edit string `yaml:"edit"`
|
||||
Edit string `yaml:"edit" json:"edit"`
|
||||
|
||||
// Functions is the template used when tools are present in the client requests
|
||||
Functions string `yaml:"function"`
|
||||
Functions string `yaml:"function" json:"function"`
|
||||
|
||||
// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
|
||||
// Note: this is mostly consumed for backends such as vllm and transformers
|
||||
// that can use the tokenizers specified in the JSON config files of the models
|
||||
UseTokenizerTemplate bool `yaml:"use_tokenizer_template"`
|
||||
UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
|
||||
|
||||
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
|
||||
// It defaults to \n
|
||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
|
||||
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
|
||||
|
||||
Multimodal string `yaml:"multimodal"`
|
||||
Multimodal string `yaml:"multimodal" json:"multimodal"`
|
||||
|
||||
JinjaTemplate bool `yaml:"jinja_template"`
|
||||
JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
|
||||
|
||||
ReplyPrefix string `yaml:"reply_prefix"`
|
||||
ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
|
||||
}
|
||||
|
||||
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||
|
||||
@@ -316,17 +316,12 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
|
||||
return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
|
||||
}
|
||||
|
||||
var filesToRemove []string
|
||||
|
||||
// Delete all the files associated to the model
|
||||
// read the model config
|
||||
galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
|
||||
}
|
||||
|
||||
var filesToRemove []string
|
||||
|
||||
// Remove additional files
|
||||
if galleryconfig != nil {
|
||||
if err == nil && galleryconfig != nil {
|
||||
for _, f := range galleryconfig.Files {
|
||||
fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
|
||||
if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
|
||||
@@ -334,6 +329,8 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
|
||||
}
|
||||
filesToRemove = append(filesToRemove, fullPath)
|
||||
}
|
||||
} else {
|
||||
log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
|
||||
}
|
||||
|
||||
for _, f := range additionalFiles {
|
||||
@@ -344,7 +341,6 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
|
||||
filesToRemove = append(filesToRemove, fullPath)
|
||||
}
|
||||
|
||||
filesToRemove = append(filesToRemove, configFile)
|
||||
filesToRemove = append(filesToRemove, galleryFile)
|
||||
|
||||
// skip duplicates
|
||||
@@ -353,11 +349,11 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
|
||||
// Removing files
|
||||
for _, f := range filesToRemove {
|
||||
if e := os.Remove(f); e != nil {
|
||||
err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e))
|
||||
log.Error().Err(e).Msgf("failed to remove file %s", f)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return os.Remove(configFile)
|
||||
}
|
||||
|
||||
// This is ***NEVER*** going to be perfect or finished.
|
||||
|
||||
228
core/http/endpoints/localai/edit_model.go
Normal file
228
core/http/endpoints/localai/edit_model.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// GetEditModelPage renders the edit model page with current configuration
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
}
|
||||
|
||||
configData, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Marshal the config to JSON for the template
|
||||
configJSON, err := json.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Render the edit page with the current configuration
|
||||
templateData := struct {
|
||||
Title string
|
||||
ModelName string
|
||||
Config *config.ModelConfig
|
||||
ConfigJSON string
|
||||
ConfigYAML string
|
||||
BaseURL string
|
||||
Version string
|
||||
}{
|
||||
Title: "LocalAI - Edit Model " + modelName,
|
||||
ModelName: modelName,
|
||||
Config: &modelConfig,
|
||||
ConfigJSON: string(configJSON),
|
||||
ConfigYAML: string(configData),
|
||||
BaseURL: httpUtils.BaseURL(c),
|
||||
Version: internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
return c.Render("views/model-editor", templateData)
|
||||
}
|
||||
}
|
||||
|
||||
// EditModelEndpoint handles updating existing model configurations
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
contentType := string(c.Context().Request.Header.ContentType())
|
||||
var req config.ModelConfig
|
||||
var err error
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
if err := yaml.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
if err := yaml.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Name == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Load the existing configuration
|
||||
configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml")
|
||||
if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
req.SetDefaults()
|
||||
|
||||
// Validate the configuration
|
||||
if !req.Validate() {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Validation failed",
|
||||
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Create the YAML file
|
||||
yamlData, err := yaml.Marshal(req)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Write to file
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
response := ModelResponse{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("Model '%s' updated successfully", modelName),
|
||||
Filename: configPath,
|
||||
Config: req,
|
||||
}
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
72
core/http/endpoints/localai/edit_model_test.go
Normal file
72
core/http/endpoints/localai/edit_model_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Edit Model test", func() {
|
||||
|
||||
var tempDir string
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tempDir, err = os.MkdirTemp("", "localai-test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
AfterEach(func() {
|
||||
os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
Context("Edit Model endpoint", func() {
|
||||
It("should edit a model", func() {
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(filepath.Join(tempDir)),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
applicationConfig := config.NewApplicationConfig(
|
||||
config.WithSystemState(systemState),
|
||||
)
|
||||
//modelLoader := model.NewModelLoader(systemState, true)
|
||||
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
|
||||
|
||||
// Define Fiber app.
|
||||
app := fiber.New()
|
||||
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
|
||||
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/import-model", requestBody)
|
||||
resp, err := app.Test(req, 5000)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
|
||||
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
|
||||
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
|
||||
|
||||
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
|
||||
resp, _ = app.Test(req, 1)
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
|
||||
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
|
||||
})
|
||||
})
|
||||
})
|
||||
148
core/http/endpoints/localai/import_model.go
Normal file
148
core/http/endpoints/localai/import_model.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ImportModelEndpoint handles creating new model configurations
|
||||
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
contentType := string(c.Context().Request.Header.ContentType())
|
||||
var modelConfig config.ModelConfig
|
||||
var err error
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
if err := yaml.Unmarshal(body, &modelConfig); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if modelConfig.Name == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
modelConfig.SetDefaults()
|
||||
|
||||
// Validate the configuration
|
||||
if !modelConfig.Validate() {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Invalid configuration",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Create the configuration file
|
||||
configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelConfig.Name+".yaml")
|
||||
if err := utils.InTrustedRoot(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model path not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Marshal to YAML for storage
|
||||
yamlData, err := yaml.Marshal(&modelConfig)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
// Return success response
|
||||
response := ModelResponse{
|
||||
Success: true,
|
||||
Message: "Model configuration created successfully",
|
||||
Filename: filepath.Base(configPath),
|
||||
}
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
13
core/http/endpoints/localai/localai_suite_test.go
Normal file
13
core/http/endpoints/localai/localai_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestLocalAIEndpoints is the single go-test entry point that hands control
// to Ginkgo: it wires Gomega assertion failures into Ginkgo's fail handler
// and then runs every spec registered in this package under the
// "LocalAI Endpoints test suite" label.
func TestLocalAIEndpoints(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "LocalAI Endpoints test suite")
}
|
||||
11
core/http/endpoints/localai/types.go
Normal file
11
core/http/endpoints/localai/types.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package localai
|
||||
|
||||
// ModelResponse represents the common response structure for model operations
|
||||
type ModelResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Config interface{} `json:"config,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Details []string `json:"details,omitempty"`
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
@@ -23,6 +24,17 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
// Import model page
|
||||
router.Get("/import-model", func(c *fiber.Ctx) error {
|
||||
return c.Render("views/model-editor", fiber.Map{
|
||||
"Title": "LocalAI - Import Model",
|
||||
"BaseURL": httpUtils.BaseURL(c),
|
||||
"Version": internal.PrintableVersion(),
|
||||
})
|
||||
})
|
||||
|
||||
// Edit model page
|
||||
router.Get("/models/edit/:name", localai.GetEditModelPage(cl, appConfig))
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
@@ -42,6 +54,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
// Custom model import endpoint
|
||||
router.Post("/models/import", localai.ImportModelEndpoint(cl, appConfig))
|
||||
|
||||
// Custom model edit endpoint
|
||||
router.Post("/models/edit/:name", localai.EditModelEndpoint(cl, appConfig))
|
||||
}
|
||||
|
||||
router.Post("/v1/detection",
|
||||
|
||||
@@ -30,6 +30,12 @@
|
||||
<span>Gallery</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
<a href="/import-model"
|
||||
class="group flex items-center bg-green-600 hover:bg-green-700 text-white py-2 px-6 rounded-lg transition duration-300 ease-in-out transform hover:scale-105 hover:shadow-lg">
|
||||
<i class="fas fa-plus mr-2"></i>
|
||||
<span>Import Custom Model</span>
|
||||
<i class="fas fa-arrow-right opacity-0 group-hover:opacity-100 group-hover:translate-x-2 ml-2 transition-all duration-300"></i>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -143,7 +149,11 @@
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
<div class="mt-4 flex justify-end">
|
||||
<div class="mt-4 flex justify-end gap-2">
|
||||
<a href="/models/edit/{{.Name}}"
|
||||
class="inline-flex items-center text-xs font-medium text-indigo-400 hover:text-indigo-300 hover:bg-indigo-900/20 rounded-md px-2 py-1 transition-colors duration-200">
|
||||
<i class="fas fa-edit mr-1.5"></i>Edit
|
||||
</a>
|
||||
<button
|
||||
class="inline-flex items-center text-xs font-medium text-red-400 hover:text-red-300 hover:bg-red-900/20 rounded-md px-2 py-1 transition-colors duration-200"
|
||||
data-twe-ripple-init=""
|
||||
@@ -176,6 +186,12 @@
|
||||
No Configuration
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="mt-4 flex justify-end">
|
||||
<span class="inline-flex items-center text-xs font-medium text-gray-400 px-2 py-1">
|
||||
<i class="fas fa-info-circle mr-1.5"></i>Cannot edit (no config)
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
1022
core/http/views/model-editor.html
Normal file
1022
core/http/views/model-editor.html
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user