Compare commits

...

1 Commits

Author SHA1 Message Date
jmorganca
87cb080a91 support other modelfile commands for image generation models 2026-01-10 12:39:44 -08:00
3 changed files with 113 additions and 14 deletions

View File

@@ -123,6 +123,21 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
// Check if FROM points to an imagegen model directory
for _, mfCmd := range modelfile.Commands {
if mfCmd.Name == "model" {
// Resolve the path relative to the Modelfile directory
fromPath := mfCmd.Args
if !filepath.IsAbs(fromPath) {
fromPath = filepath.Join(filepath.Dir(filename), fromPath)
}
if imagegen.IsTensorModelDir(fromPath) {
return imagegenclient.CreateModelFromModelfile(args[0], fromPath, modelfile.Commands, p)
}
break
}
}
status := "gathering model components"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)

View File

@@ -17,7 +17,10 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
@@ -28,14 +31,41 @@ import (
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
return CreateModelFromModelfile(modelName, modelDir, nil, p)
}
// CreateModelFromModelfile imports a tensor-based model using Modelfile commands.
// Extracts LICENSE, REQUIRES, and PARAMETER commands from the Modelfile.
func CreateModelFromModelfile(modelName, modelDir string, commands []parser.Command, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
// Extract metadata from Modelfile commands
var licenses []string
var requires string
params := make(map[string]any)
for _, c := range commands {
switch c.Name {
case "license":
licenses = append(licenses, c.Args)
case "requires":
requires = c.Args
case "model":
// skip - already handled by caller
default:
// Treat as parameter (steps, width, height, seed, etc.)
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err == nil {
for k, v := range ps {
params[k] = v
}
}
}
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
@@ -46,8 +76,6 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
@@ -56,15 +84,12 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// Create tensor layer callback
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
@@ -80,24 +105,27 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
// Use Modelfile REQUIRES if specified, otherwise use minimum
if requires == "" {
requires = MinOllamaVersion
}
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
Requires: requires,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
// Convert to server.Layer
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
@@ -108,10 +136,31 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
}
}
// Add license layers
for _, license := range licenses {
layer, err := server.NewLayer(strings.NewReader(license), "application/vnd.ollama.image.license")
if err != nil {
return fmt.Errorf("failed to create license layer: %w", err)
}
serverLayers = append(serverLayers, layer)
}
// Add parameters layer
if len(params) > 0 {
paramsJSON, err := json.Marshal(params)
if err != nil {
return fmt.Errorf("failed to marshal parameters: %w", err)
}
layer, err := server.NewLayer(bytes.NewReader(paramsJSON), "application/vnd.ollama.image.params")
if err != nil {
return fmt.Errorf("failed to create params layer: %w", err)
}
serverLayers = append(serverLayers, layer)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg

View File

@@ -0,0 +1,35 @@
package client
import (
"testing"
"github.com/ollama/ollama/parser"
)
func TestCreateModelFromModelfileExtractsMetadata(t *testing.T) {
// Test that the command parsing works correctly
commands := []parser.Command{
{Name: "model", Args: "./weights/test"},
{Name: "license", Args: "Apache-2.0"},
{Name: "requires", Args: "0.15.0"},
{Name: "num_predict", Args: "12"},
{Name: "seed", Args: "42"},
}
// We can't easily test the full function without a real model dir,
// but we can verify the commands are valid parser.Command types
for _, c := range commands {
if c.Name == "" {
t.Error("Command name should not be empty")
}
}
}
func TestMinOllamaVersion(t *testing.T) {
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion[0] < '0' || MinOllamaVersion[0] > '9' {
t.Errorf("MinOllamaVersion should start with a number, got %q", MinOllamaVersion)
}
}