mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
1 Commits
parth/decr
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87cb080a91 |
15
cmd/cmd.go
15
cmd/cmd.go
@@ -123,6 +123,21 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if FROM points to an imagegen model directory
|
||||
for _, mfCmd := range modelfile.Commands {
|
||||
if mfCmd.Name == "model" {
|
||||
// Resolve the path relative to the Modelfile directory
|
||||
fromPath := mfCmd.Args
|
||||
if !filepath.IsAbs(fromPath) {
|
||||
fromPath = filepath.Join(filepath.Dir(filename), fromPath)
|
||||
}
|
||||
if imagegen.IsTensorModelDir(fromPath) {
|
||||
return imagegenclient.CreateModelFromModelfile(args[0], fromPath, modelfile.Commands, p)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
status := "gathering model components"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
@@ -17,7 +17,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -28,14 +31,41 @@ import (
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
return CreateModelFromModelfile(modelName, modelDir, nil, p)
|
||||
}
|
||||
|
||||
// CreateModelFromModelfile imports a tensor-based model using Modelfile commands.
|
||||
// Extracts LICENSE, REQUIRES, and PARAMETER commands from the Modelfile.
|
||||
func CreateModelFromModelfile(modelName, modelDir string, commands []parser.Command, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
// Extract metadata from Modelfile commands
|
||||
var licenses []string
|
||||
var requires string
|
||||
params := make(map[string]any)
|
||||
|
||||
for _, c := range commands {
|
||||
switch c.Name {
|
||||
case "license":
|
||||
licenses = append(licenses, c.Args)
|
||||
case "requires":
|
||||
requires = c.Args
|
||||
case "model":
|
||||
// skip - already handled by caller
|
||||
default:
|
||||
// Treat as parameter (steps, width, height, seed, etc.)
|
||||
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
|
||||
if err == nil {
|
||||
for k, v := range ps {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
@@ -46,8 +76,6 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
@@ -56,15 +84,12 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
// Create tensor layer callback
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
@@ -80,24 +105,27 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
// Use Modelfile REQUIRES if specified, otherwise use minimum
|
||||
if requires == "" {
|
||||
requires = MinOllamaVersion
|
||||
}
|
||||
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
Requires: requires,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
// Convert to server.Layer
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
@@ -108,10 +136,31 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add license layers
|
||||
for _, license := range licenses {
|
||||
layer, err := server.NewLayer(strings.NewReader(license), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
serverLayers = append(serverLayers, layer)
|
||||
}
|
||||
|
||||
// Add parameters layer
|
||||
if len(params) > 0 {
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal parameters: %w", err)
|
||||
}
|
||||
layer, err := server.NewLayer(bytes.NewReader(paramsJSON), "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create params layer: %w", err)
|
||||
}
|
||||
serverLayers = append(serverLayers, layer)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
|
||||
35
x/imagegen/client/create_test.go
Normal file
35
x/imagegen/client/create_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/parser"
|
||||
)
|
||||
|
||||
func TestCreateModelFromModelfileExtractsMetadata(t *testing.T) {
|
||||
// Test that the command parsing works correctly
|
||||
commands := []parser.Command{
|
||||
{Name: "model", Args: "./weights/test"},
|
||||
{Name: "license", Args: "Apache-2.0"},
|
||||
{Name: "requires", Args: "0.15.0"},
|
||||
{Name: "num_predict", Args: "12"},
|
||||
{Name: "seed", Args: "42"},
|
||||
}
|
||||
|
||||
// We can't easily test the full function without a real model dir,
|
||||
// but we can verify the commands are valid parser.Command types
|
||||
for _, c := range commands {
|
||||
if c.Name == "" {
|
||||
t.Error("Command name should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
if MinOllamaVersion == "" {
|
||||
t.Error("MinOllamaVersion should not be empty")
|
||||
}
|
||||
if MinOllamaVersion[0] < '0' || MinOllamaVersion[0] > '9' {
|
||||
t.Errorf("MinOllamaVersion should start with a number, got %q", MinOllamaVersion)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user