Compare commits

..

2 Commits

Author SHA1 Message Date
jmorganca
5dc20e91d0 simplify 2026-01-19 16:15:15 -08:00
jmorganca
bda8cb7403 cmd: auto-detect image generation models during create
Remove the --experimental flag requirement for creating image generation
models. Now ollama create automatically detects imagegen models by checking
for model_index.json in the model directory and routes them to the direct
creation path.

This simplifies the workflow for creating FLUX and other diffusion models -
users no longer need to pass --experimental.
2026-01-19 15:23:09 -08:00
5 changed files with 45 additions and 252 deletions

View File

@@ -749,7 +749,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`

View File

@@ -101,67 +101,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
@@ -197,6 +136,28 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
// Check if this is a tensor model (image generation) and handle it directly
quantize, _ := cmd.Flags().GetString("quantize")
modelDir := filepath.Dir(filename)
for _, cmd := range modelfile.Commands {
if cmd.Name == "model" {
if filepath.IsAbs(cmd.Args) {
modelDir = cmd.Args
} else {
modelDir = filepath.Join(filepath.Dir(filename), cmd.Args)
}
break
}
}
if create.IsTensorModelDir(modelDir) {
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: xcreateclient.ExtractModelfileConfig(modelfile),
}, p)
}
status := "gathering model components"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
@@ -208,7 +169,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop()
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
}
@@ -1815,22 +1775,15 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",

View File

@@ -1,174 +0,0 @@
//go:build integration
package integration
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, testEndpoint, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, testEndpoint, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the OpenAI-compatible image generation API and returns the base64 image data
func generateImage(ctx context.Context, endpoint, model, prompt string) (string, error) {
reqBody := imagegenapi.ImageGenerationRequest{
Model: model,
Prompt: prompt,
N: 1,
Size: "512x512",
ResponseFormat: "b64_json",
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
url := fmt.Sprintf("http://%s/v1/images/generations", endpoint)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
var buf bytes.Buffer
buf.ReadFrom(resp.Body)
return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, buf.String())
}
var genResp imagegenapi.ImageGenerationResponse
if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}
if len(genResp.Data) == 0 {
return "", fmt.Errorf("no image data in response")
}
return genResp.Data[0].B64JSON, nil
}

View File

@@ -1149,9 +1149,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
ModelInfo: make(map[string]any),
}
if m.Config.RemoteHost != "" {

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"io"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
@@ -280,3 +281,19 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
return layers, nil
}
// ExtractModelfileConfig extracts template, system, and license from a parsed Modelfile.
func ExtractModelfileConfig(modelfile *parser.Modelfile) *ModelfileConfig {
mfConfig := &ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
return mfConfig
}