mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 21:08:16 -05:00
Compare commits
1 Commits
parth/decr
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87cb080a91 |
15
cmd/cmd.go
15
cmd/cmd.go
@@ -123,6 +123,21 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if FROM points to an imagegen model directory
|
||||||
|
for _, mfCmd := range modelfile.Commands {
|
||||||
|
if mfCmd.Name == "model" {
|
||||||
|
// Resolve the path relative to the Modelfile directory
|
||||||
|
fromPath := mfCmd.Args
|
||||||
|
if !filepath.IsAbs(fromPath) {
|
||||||
|
fromPath = filepath.Join(filepath.Dir(filename), fromPath)
|
||||||
|
}
|
||||||
|
if imagegen.IsTensorModelDir(fromPath) {
|
||||||
|
return imagegenclient.CreateModelFromModelfile(args[0], fromPath, modelfile.Commands, p)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
status := "gathering model components"
|
status := "gathering model components"
|
||||||
spinner := progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add(status, spinner)
|
p.Add(status, spinner)
|
||||||
|
|||||||
@@ -17,7 +17,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/server"
|
"github.com/ollama/ollama/server"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -28,14 +31,41 @@ import (
|
|||||||
const MinOllamaVersion = "0.14.0"
|
const MinOllamaVersion = "0.14.0"
|
||||||
|
|
||||||
// CreateModel imports a tensor-based model from a local directory.
|
// CreateModel imports a tensor-based model from a local directory.
|
||||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
|
||||||
//
|
|
||||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
|
||||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||||
|
return CreateModelFromModelfile(modelName, modelDir, nil, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateModelFromModelfile imports a tensor-based model using Modelfile commands.
|
||||||
|
// Extracts LICENSE, REQUIRES, and PARAMETER commands from the Modelfile.
|
||||||
|
func CreateModelFromModelfile(modelName, modelDir string, commands []parser.Command, p *progress.Progress) error {
|
||||||
if !imagegen.IsTensorModelDir(modelDir) {
|
if !imagegen.IsTensorModelDir(modelDir) {
|
||||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract metadata from Modelfile commands
|
||||||
|
var licenses []string
|
||||||
|
var requires string
|
||||||
|
params := make(map[string]any)
|
||||||
|
|
||||||
|
for _, c := range commands {
|
||||||
|
switch c.Name {
|
||||||
|
case "license":
|
||||||
|
licenses = append(licenses, c.Args)
|
||||||
|
case "requires":
|
||||||
|
requires = c.Args
|
||||||
|
case "model":
|
||||||
|
// skip - already handled by caller
|
||||||
|
default:
|
||||||
|
// Treat as parameter (steps, width, height, seed, etc.)
|
||||||
|
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
|
||||||
|
if err == nil {
|
||||||
|
for k, v := range ps {
|
||||||
|
params[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
status := "importing image generation model"
|
status := "importing image generation model"
|
||||||
spinner := progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add("imagegen", spinner)
|
p.Add("imagegen", spinner)
|
||||||
@@ -46,8 +76,6 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return imagegen.LayerInfo{}, err
|
return imagegen.LayerInfo{}, err
|
||||||
}
|
}
|
||||||
layer.Name = name
|
|
||||||
|
|
||||||
return imagegen.LayerInfo{
|
return imagegen.LayerInfo{
|
||||||
Digest: layer.Digest,
|
Digest: layer.Digest,
|
||||||
Size: layer.Size,
|
Size: layer.Size,
|
||||||
@@ -56,15 +84,12 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tensor layer callback for individual tensors
|
// Create tensor layer callback
|
||||||
// name is path-style: "component/tensor_name"
|
|
||||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
||||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return imagegen.LayerInfo{}, err
|
return imagegen.LayerInfo{}, err
|
||||||
}
|
}
|
||||||
layer.Name = name
|
|
||||||
|
|
||||||
return imagegen.LayerInfo{
|
return imagegen.LayerInfo{
|
||||||
Digest: layer.Digest,
|
Digest: layer.Digest,
|
||||||
Size: layer.Size,
|
Size: layer.Size,
|
||||||
@@ -80,24 +105,27 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
return fmt.Errorf("invalid model name: %s", modelName)
|
return fmt.Errorf("invalid model name: %s", modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a proper config blob with version requirement
|
// Use Modelfile REQUIRES if specified, otherwise use minimum
|
||||||
|
if requires == "" {
|
||||||
|
requires = MinOllamaVersion
|
||||||
|
}
|
||||||
|
|
||||||
configData := model.ConfigV2{
|
configData := model.ConfigV2{
|
||||||
ModelFormat: "safetensors",
|
ModelFormat: "safetensors",
|
||||||
Capabilities: []string{"image"},
|
Capabilities: []string{"image"},
|
||||||
Requires: MinOllamaVersion,
|
Requires: requires,
|
||||||
}
|
}
|
||||||
configJSON, err := json.Marshal(configData)
|
configJSON, err := json.Marshal(configData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal config: %w", err)
|
return fmt.Errorf("failed to marshal config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create config layer blob
|
|
||||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create config layer: %w", err)
|
return fmt.Errorf("failed to create config layer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
// Convert to server.Layer
|
||||||
serverLayers := make([]server.Layer, len(layers))
|
serverLayers := make([]server.Layer, len(layers))
|
||||||
for i, l := range layers {
|
for i, l := range layers {
|
||||||
serverLayers[i] = server.Layer{
|
serverLayers[i] = server.Layer{
|
||||||
@@ -108,10 +136,31 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add license layers
|
||||||
|
for _, license := range licenses {
|
||||||
|
layer, err := server.NewLayer(strings.NewReader(license), "application/vnd.ollama.image.license")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create license layer: %w", err)
|
||||||
|
}
|
||||||
|
serverLayers = append(serverLayers, layer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add parameters layer
|
||||||
|
if len(params) > 0 {
|
||||||
|
paramsJSON, err := json.Marshal(params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal parameters: %w", err)
|
||||||
|
}
|
||||||
|
layer, err := server.NewLayer(bytes.NewReader(paramsJSON), "application/vnd.ollama.image.params")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create params layer: %w", err)
|
||||||
|
}
|
||||||
|
serverLayers = append(serverLayers, layer)
|
||||||
|
}
|
||||||
|
|
||||||
return server.WriteManifest(name, configLayer, serverLayers)
|
return server.WriteManifest(name, configLayer, serverLayers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Progress callback
|
|
||||||
progressFn := func(msg string) {
|
progressFn := func(msg string) {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
status = msg
|
status = msg
|
||||||
|
|||||||
35
x/imagegen/client/create_test.go
Normal file
35
x/imagegen/client/create_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateModelFromModelfileExtractsMetadata(t *testing.T) {
|
||||||
|
// Test that the command parsing works correctly
|
||||||
|
commands := []parser.Command{
|
||||||
|
{Name: "model", Args: "./weights/test"},
|
||||||
|
{Name: "license", Args: "Apache-2.0"},
|
||||||
|
{Name: "requires", Args: "0.15.0"},
|
||||||
|
{Name: "num_predict", Args: "12"},
|
||||||
|
{Name: "seed", Args: "42"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can't easily test the full function without a real model dir,
|
||||||
|
// but we can verify the commands are valid parser.Command types
|
||||||
|
for _, c := range commands {
|
||||||
|
if c.Name == "" {
|
||||||
|
t.Error("Command name should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinOllamaVersion(t *testing.T) {
|
||||||
|
if MinOllamaVersion == "" {
|
||||||
|
t.Error("MinOllamaVersion should not be empty")
|
||||||
|
}
|
||||||
|
if MinOllamaVersion[0] < '0' || MinOllamaVersion[0] > '9' {
|
||||||
|
t.Errorf("MinOllamaVersion should start with a number, got %q", MinOllamaVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user