mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 16:51:44 -04:00
feat(stablediffusion-ggml): add support to ref images (flux Kontext) (#5935)
* feat(stablediffusion-ggml): add support to ref images Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add it to the model gallery Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
4438b4361e
commit
3d22bfc27c
@@ -7,7 +7,7 @@ import (
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
|
||||
|
||||
opts := ModelOptions(backendConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(
|
||||
@@ -33,6 +33,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
Dst: dst,
|
||||
Src: src,
|
||||
EnableParameters: backendConfig.Diffusers.EnableParameters,
|
||||
RefImages: refImages,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -79,49 +79,37 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
// Process input images (for img2img/inpainting)
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
src = processImageFile(input.File, appConfig.GeneratedContentDir)
|
||||
if src != "" {
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
}
|
||||
|
||||
fileData := []byte{}
|
||||
var err error
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
out, err := downloadFile(input.File)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed downloading file:%w", err)
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading file:%w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
||||
if err != nil {
|
||||
return err
|
||||
// Process multiple input images
|
||||
var inputImages []string
|
||||
if len(input.Files) > 0 {
|
||||
for _, file := range input.Files {
|
||||
processedFile := processImageFile(file, appConfig.GeneratedContentDir)
|
||||
if processedFile != "" {
|
||||
inputImages = append(inputImages, processedFile)
|
||||
defer os.RemoveAll(processedFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
// Process reference images
|
||||
var refImages []string
|
||||
if len(input.RefImages) > 0 {
|
||||
for _, file := range input.RefImages {
|
||||
processedFile := processImageFile(file, appConfig.GeneratedContentDir)
|
||||
if processedFile != "" {
|
||||
refImages = append(refImages, processedFile)
|
||||
defer os.RemoveAll(processedFile)
|
||||
}
|
||||
}
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
src = outputFile.Name()
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
@@ -202,7 +190,13 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
||||
// Use the first input image as src if available, otherwise use the original src
|
||||
inputSrc := src
|
||||
if len(inputImages) > 0 {
|
||||
inputSrc = inputImages[0]
|
||||
}
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -243,3 +237,51 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// processImageFile handles a single image file (URL or base64) and returns the path to the temporary file
|
||||
func processImageFile(file string, generatedContentDir string) string {
|
||||
fileData := []byte{}
|
||||
var err error
|
||||
|
||||
// check if file is an URL, if so download it and save it to a temporary file
|
||||
if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") {
|
||||
out, err := downloadFile(file)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed downloading file: %s", file)
|
||||
return ""
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed reading downloaded file: %s", out)
|
||||
return ""
|
||||
}
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(file)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed decoding base64 file")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(generatedContentDir, "b64")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed creating temporary file")
|
||||
return ""
|
||||
}
|
||||
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
log.Error().Err(err).Msg("Failed writing to temporary file")
|
||||
return ""
|
||||
}
|
||||
outputFile.Close()
|
||||
|
||||
return outputFile.Name()
|
||||
}
|
||||
|
||||
@@ -141,6 +141,10 @@ type OpenAIRequest struct {
|
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"`
|
||||
// Multiple input images for img2img or inpainting
|
||||
Files []string `json:"files,omitempty"`
|
||||
// Reference images for models that support them (e.g., Flux Kontext)
|
||||
RefImages []string `json:"ref_images,omitempty"`
|
||||
//whisper/image
|
||||
ResponseFormat interface{} `json:"response_format,omitempty"`
|
||||
// image
|
||||
|
||||
Reference in New Issue
Block a user