mirror of
https://github.com/ollama/ollama.git
synced 2026-01-21 22:10:58 -05:00
Compare commits
9 Commits
imagegen-a
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9dac1aaef | ||
|
|
b5d0f72f16 | ||
|
|
148a1be0a3 | ||
|
|
d6dd430abd | ||
|
|
ae78112c50 | ||
|
|
01cf7445f3 | ||
|
|
31085d5e53 | ||
|
|
c42e9d244f | ||
|
|
e98b5e8b4e |
@@ -749,7 +749,7 @@ type ShowResponse struct {
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
|
||||
@@ -313,6 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
case "Lfm2ForCausalLM":
|
||||
conv = &lfm2Model{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
100
convert/convert_lfm2.go
Normal file
100
convert/convert_lfm2.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type lfm2Model struct {
|
||||
ModelParameters
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
NormEps float32 `json:"norm_eps"`
|
||||
ConvLCache uint32 `json:"conv_L_cache"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
TieEmbedding bool `json:"tie_embedding"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*lfm2Model)(nil)
|
||||
|
||||
func (p *lfm2Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "lfm2"
|
||||
kv["lfm2.vocab_size"] = p.VocabSize
|
||||
kv["lfm2.block_count"] = p.NumHiddenLayers
|
||||
kv["lfm2.embedding_length"] = p.HiddenSize
|
||||
kv["lfm2.feed_forward_length"] = p.IntermediateSize
|
||||
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
// Build per-layer KV head count array based on layer_types
|
||||
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
|
||||
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
|
||||
for i := range p.NumHiddenLayers {
|
||||
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
|
||||
kvHeadCounts[i] = p.NumKeyValueHeads
|
||||
}
|
||||
}
|
||||
|
||||
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
|
||||
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
|
||||
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||
kv["lfm2.rope.freq_base"] = p.RopeTheta
|
||||
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
|
||||
// Squeeze conv weights: [D, 1, K] -> [D, K]
|
||||
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
|
||||
if len(shape) == 3 && shape[1] == 1 {
|
||||
shape = []uint64{shape[0], shape[2]}
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: slices.Clone(shape),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *lfm2Model) Replacements() []string {
|
||||
return []string{
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.embedding_norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
"operator_norm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.out_proj", "attn_output",
|
||||
"self_attn.q_layernorm", "attn_q_norm",
|
||||
"self_attn.k_layernorm", "attn_k_norm",
|
||||
"conv.conv", "shortconv.conv",
|
||||
"conv.in_proj", "shortconv.in_proj",
|
||||
"conv.out_proj", "shortconv.out_proj",
|
||||
"feed_forward.w1", "ffn_gate",
|
||||
"feed_forward.w2", "ffn_down",
|
||||
"feed_forward.w3", "ffn_up",
|
||||
"ffn_norm", "ffn_norm",
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,7 @@ const (
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
|
||||
43
docs/integrations/index.mdx
Normal file
43
docs/integrations/index.mdx
Normal file
@@ -0,0 +1,43 @@
|
||||
---
|
||||
title: Integrations
|
||||
---
|
||||
|
||||
Ollama integrates with a wide range of tools, including:
|
||||
|
||||
## Coding Agents
|
||||
|
||||
Coding assistants that can read, modify, and execute code in your projects.
|
||||
|
||||
- [Claude Code](/integrations/claude-code)
|
||||
- [Codex](/integrations/codex)
|
||||
- [Droid](/integrations/droid)
|
||||
- [Goose](/integrations/goose)
|
||||
|
||||
## IDEs & Editors
|
||||
|
||||
Native integrations for popular development environments.
|
||||
|
||||
- [VS Code](/integrations/vscode)
|
||||
- [Cline](/integrations/cline)
|
||||
- [Roo Code](/integrations/roo-code)
|
||||
- [JetBrains](/integrations/jetbrains)
|
||||
- [Xcode](/integrations/xcode)
|
||||
- [Zed](/integrations/zed)
|
||||
|
||||
## Chat & RAG
|
||||
|
||||
Chat interfaces and retrieval-augmented generation platforms.
|
||||
|
||||
- [Onyx](/integrations/onyx)
|
||||
|
||||
## Automation
|
||||
|
||||
Workflow automation platforms with AI integration.
|
||||
|
||||
- [n8n](/integrations/n8n)
|
||||
|
||||
## Notebooks
|
||||
|
||||
Interactive computing environments with AI capabilities.
|
||||
|
||||
- [marimo](/integrations/marimo)
|
||||
@@ -270,6 +270,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"lfm2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -859,6 +860,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
|
||||
148
integration/imagegen_test.go
Normal file
148
integration/imagegen_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestImageGeneration(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 8)
|
||||
|
||||
type testCase struct {
|
||||
imageGenModel string
|
||||
visionModel string
|
||||
prompt string
|
||||
expectedWords []string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
imageGenModel: "jmorgan/z-image-turbo",
|
||||
visionModel: "llama3.2-vision",
|
||||
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
|
||||
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Pull both models
|
||||
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
|
||||
t.Fatalf("failed to pull image gen model: %v", err)
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
|
||||
t.Fatalf("failed to pull vision model: %v", err)
|
||||
}
|
||||
|
||||
// Generate the image
|
||||
t.Logf("Generating image with prompt: %s", tc.prompt)
|
||||
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "image generation not available") {
|
||||
t.Skip("Target system does not support image generation")
|
||||
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
|
||||
t.Skip("Windows does not support image generation yet")
|
||||
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
|
||||
t.Skip("Driver is too old")
|
||||
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
|
||||
t.Skip("insufficient memory for image generation")
|
||||
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
|
||||
t.Skip("CUDA GPU is not available")
|
||||
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
|
||||
// most likely linux arm - not supported yet
|
||||
t.Skip("unsupported architecture")
|
||||
}
|
||||
t.Fatalf("failed to generate image: %v", err)
|
||||
}
|
||||
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode image: %v", err)
|
||||
}
|
||||
t.Logf("Generated image: %d bytes", len(imageData))
|
||||
|
||||
// Preload vision model and check GPU loading
|
||||
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load vision model: %v", err)
|
||||
}
|
||||
|
||||
// Use vision model to describe the image
|
||||
chatReq := api.ChatRequest{
|
||||
Model: tc.visionModel,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
|
||||
Images: []api.ImageData{imageData},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Verify the vision model's response contains expected keywords
|
||||
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
|
||||
if response != nil {
|
||||
t.Logf("Vision model response: %s", response.Content)
|
||||
|
||||
// Additional detailed check for keywords
|
||||
content := strings.ToLower(response.Content)
|
||||
foundWords := []string{}
|
||||
missingWords := []string{}
|
||||
for _, word := range tc.expectedWords {
|
||||
if strings.Contains(content, word) {
|
||||
foundWords = append(foundWords, word)
|
||||
} else {
|
||||
missingWords = append(missingWords, word)
|
||||
}
|
||||
}
|
||||
t.Logf("Found keywords: %v", foundWords)
|
||||
if len(missingWords) > 0 {
|
||||
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateImage calls the Ollama API to generate an image and returns the base64 image data
|
||||
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
|
||||
var imageBase64 string
|
||||
|
||||
err := client.Generate(ctx, &api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
}, func(resp api.GenerateResponse) error {
|
||||
if resp.Image != "" {
|
||||
imageBase64 = resp.Image
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate image: %w", err)
|
||||
}
|
||||
|
||||
if imageBase64 == "" {
|
||||
return "", fmt.Errorf("no image data in response")
|
||||
}
|
||||
|
||||
return imageBase64, nil
|
||||
}
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
@@ -143,6 +144,7 @@ var (
|
||||
"granite3.3",
|
||||
"hermes3",
|
||||
"internlm2",
|
||||
"lfm2.5-thinking",
|
||||
"llama-guard3",
|
||||
"llama-pro",
|
||||
"llama2-chinese",
|
||||
@@ -263,6 +265,7 @@ var (
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"lfm2.5-thinking",
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -14,7 +14,7 @@ type Layer struct {
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
|
||||
status string
|
||||
Status string `json:"-"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
)
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
blobs, err := BlobsPath("")
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: n,
|
||||
status: fmt.Sprintf("%s %s", status, digest),
|
||||
Status: fmt.Sprintf("%s %s", status, digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
From: from,
|
||||
status: fmt.Sprintf("using existing layer %s", digest),
|
||||
Status: fmt.Sprintf("using existing layer %s", digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||
return nil, errors.New("opening layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
|
||||
}
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
blob, err := BlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manifest) Digest() string {
|
||||
return m.digest
|
||||
}
|
||||
|
||||
func (m *Manifest) FileInfo() os.FileInfo {
|
||||
return m.fi
|
||||
}
|
||||
|
||||
// ReadConfigJSON reads and unmarshals a config layer as JSON.
|
||||
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
||||
for _, layer := range m.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
|
||||
blobPath, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("config %q not found in manifest", configPath)
|
||||
}
|
||||
|
||||
func (m *Manifest) Remove() error {
|
||||
if err := os.Remove(m.filepath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
95
manifest/paths.go
Normal file
95
manifest/paths.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
|
||||
func Path() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathForName returns the path to the manifest file for a specific model name.
|
||||
func PathForName(n model.Name) (string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(manifests, n.Filepath()), nil
|
||||
}
|
||||
|
||||
func BlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PruneDirectory walks path depth-first and removes every directory that is
// empty once its children have been pruned, including path itself.
// Non-directories and symlinks are left untouched and reported as success.
// The first filesystem error encountered is returned.
func PruneDirectory(path string) error {
	info, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// Only real directories are candidates for pruning.
	if !info.IsDir() || info.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	for _, child := range children {
		if err := PruneDirectory(filepath.Join(path, child.Name())); err != nil {
			return err
		}
	}

	// Re-read the directory: pruning the children may have emptied it.
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	if len(remaining) > 0 {
		return nil
	}

	return os.Remove(path)
}
|
||||
@@ -162,6 +162,7 @@ type Tensor interface {
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
|
||||
SSMConv(ctx Context, kernel Tensor) Tensor
|
||||
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
|
||||
@@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
|
||||
return tt
|
||||
}
|
||||
|
||||
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
||||
410
model/models/lfm2/cache.go
Normal file
410
model/models/lfm2/cache.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var _ kvcache.Cache = (*HybridCache)(nil)
|
||||
|
||||
// HybridCache stores:
|
||||
// - a standard causal KV cache for attention layers
|
||||
// - a per-sequence recurrent conv state for shortconv layers
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
|
||||
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
|
||||
type HybridCache struct {
|
||||
kv *kvcache.Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
hiddenSize int
|
||||
dConv int
|
||||
|
||||
// slot mapping for recurrent state
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
// track any error from EnsureWritable to propagate later
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
|
||||
return &HybridCache{
|
||||
kv: kvcache.NewCausalCache(shift),
|
||||
hiddenSize: hiddenSize,
|
||||
dConv: dConv,
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *HybridCache) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for shortconv.
|
||||
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
|
||||
seqCounts := make(map[int]int)
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if _, ok := seqCounts[s]; !ok {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
if len(c.curSeqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return kvcache.ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
// When reserving memory for estimation, use fake slot assignments
|
||||
// without modifying permanent state (slotForSeq, refCount)
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
slots := make([]int32, nSeqs)
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
slots[i] = int32(i)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int // track newly allocated slots that need zeroing
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
// Zero conv state for newly allocated slots to clear stale data from previous sequences
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroConvSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
// Create a tensor for the current slots
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
// Reset writable state for new forward pass
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, kvcache.ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) freeSlot(slot int) {
|
||||
// Bounds check before freeing
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroConvSlots zeros the conv state for the given slots across all layers.
|
||||
// This must be called when recycling slots to prevent stale state from affecting new sequences.
|
||||
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 || len(c.convStates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Use input context for creating tensors
|
||||
inputCtx := ctx.Input()
|
||||
|
||||
// Create slot indices tensor
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, s := range slots {
|
||||
slotIndices[i] = int32(s)
|
||||
}
|
||||
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
|
||||
|
||||
// Create zero tensor for the slots (SetRows requires F32 source)
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
|
||||
|
||||
// Zero each layer's conv state for these slots
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
|
||||
// Returns an error if slot allocation fails.
|
||||
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Bounds check
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
// Copy existing conv state for all initialized layers
|
||||
for _, buf := range c.convStates {
|
||||
// buf: [dConv*hiddenSize, maxSlots]
|
||||
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
|
||||
// SetRows requires F32 source
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild current slots tensor
|
||||
slots := make([]int32, len(c.curSlots))
|
||||
for i, v := range c.curSlots {
|
||||
slots[i] = int32(v)
|
||||
}
|
||||
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
|
||||
// On the first write to dst, EnsureWritable will create a private slot.
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
// Bounds check before decrementing
|
||||
if dstSlot >= 0 && dstSlot < len(c.refCount) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
// src may not have a slot yet; dst will allocate on demand
|
||||
return
|
||||
}
|
||||
|
||||
// Bounds check before incrementing
|
||||
if srcSlot >= 0 && srcSlot < len(c.refCount) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HybridCache) CanResume(seq int, pos int32) bool {
|
||||
return c.kv.CanResume(seq, pos)
|
||||
}
|
||||
|
||||
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For recurrent state, any removal invalidates the state because
|
||||
// the state at position N depends on all previous positions.
|
||||
// Drop the slot mapping so it resets on next use.
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bounds check
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
|
||||
// Returns an error if copy-on-write allocation fails.
|
||||
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
|
||||
if c.writableError != nil {
|
||||
return nil, c.writableError
|
||||
}
|
||||
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
cur := buf.Rows(ctx, c.slotsTensor())
|
||||
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes a new conv state for current batch sequences.
|
||||
// newState must have shape [dConv, hiddenSize, nSeqs].
|
||||
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(ctx, layer)
|
||||
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
|
||||
// SetRows requires F32 source
|
||||
srcF32 := src.Cast(ctx, ml.DTypeF32)
|
||||
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
|
||||
func (c *HybridCache) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *HybridCache) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
444
model/models/lfm2/cache_test.go
Normal file
444
model/models/lfm2/cache_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// TestHybridCache tests verify the slot management logic of HybridCache.
|
||||
// These tests focus on the recurrent state slot allocation, reference counting,
|
||||
// and copy-on-write semantics without requiring a full ML backend.
|
||||
|
||||
// createSlotOnlyCache creates a HybridCache with only the slot management
|
||||
// fields initialized. Used to test slot logic in isolation.
|
||||
func createSlotOnlyCache(maxSequences int) *HybridCache {
|
||||
return &HybridCache{
|
||||
hiddenSize: 256,
|
||||
dConv: 3,
|
||||
maxSequences: maxSequences,
|
||||
refCount: make([]int, maxSequences),
|
||||
freeSlots: initFreeSlots(maxSequences),
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
// initFreeSlots returns the slot indices n-1, n-2, ..., 0 in that order, i.e.
// slot 0 is the last element of the returned slice.
func initFreeSlots(n int) []int {
	slots := make([]int, 0, n)
	for next := n; next > 0; next-- {
		slots = append(slots, next-1)
	}
	return slots
}
|
||||
|
||||
func TestHybridCache_SlotAllocation(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Verify initial state
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Allocate all slots
|
||||
for range 4 {
|
||||
slot, err := cache.allocSlot()
|
||||
if err != nil {
|
||||
t.Fatalf("allocSlot failed: %v", err)
|
||||
}
|
||||
cache.refCount[slot] = 1
|
||||
}
|
||||
|
||||
// Should be full now
|
||||
if len(cache.freeSlots) != 0 {
|
||||
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Trying to allocate another should fail
|
||||
_, err := cache.allocSlot()
|
||||
if err != kvcache.ErrKvCacheFull {
|
||||
t.Errorf("expected ErrKvCacheFull, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotReuse(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate a slot
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free it
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
|
||||
// Allocate again - should get the same slot back (LIFO)
|
||||
slot2, _ := cache.allocSlot()
|
||||
if slot2 != slot1 {
|
||||
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Simulate sharing slot with seq 2 (copy-on-write style)
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Should share the same slot
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
}
|
||||
|
||||
// Ref count should be 2
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Share with seq 2
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Unshare seq 2
|
||||
cache.refCount[slot1]--
|
||||
delete(cache.slotForSeq, 2)
|
||||
|
||||
// Ref count should be back to 1
|
||||
if cache.refCount[slot1] != 1 {
|
||||
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Seq 2 should no longer have a slot
|
||||
if _, ok := cache.slotForSeq[2]; ok {
|
||||
t.Error("seq 2 should not have a slot after unshare")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free the slot when refCount drops to 0
|
||||
cache.refCount[slot1]--
|
||||
if cache.refCount[slot1] <= 0 {
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
}
|
||||
delete(cache.slotForSeq, 1)
|
||||
|
||||
// Slot should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Ref count should be 0
|
||||
if cache.refCount[slot1] != 0 {
|
||||
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotOverwrite(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slots for seq 1 and seq 2
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
slot2, _ := cache.allocSlot()
|
||||
cache.slotForSeq[2] = slot2
|
||||
cache.refCount[slot2] = 1
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Simulate overwriting seq 2's slot with slot1 (sharing)
|
||||
// First free the old slot
|
||||
cache.refCount[slot2]--
|
||||
if cache.refCount[slot2] <= 0 {
|
||||
cache.refCount[slot2] = 0
|
||||
cache.freeSlot(slot2)
|
||||
}
|
||||
// Then share slot1
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Seq 2 should now share slot1
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
}
|
||||
|
||||
// Old slot2 should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots+1 {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_BoundsChecking(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Test freeing invalid slot (should not panic)
|
||||
cache.freeSlot(-1)
|
||||
cache.freeSlot(100) // out of bounds
|
||||
|
||||
// freeSlot does bounds checking, so invalid slots should be ignored
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
|
||||
cache := createSlotOnlyCache(8)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Fork to seq 2, 3, 4 (all share slot1)
|
||||
for _, seq := range []int{2, 3, 4} {
|
||||
cache.slotForSeq[seq] = slot1
|
||||
cache.refCount[slot1]++
|
||||
}
|
||||
|
||||
// Ref count should be 4
|
||||
if cache.refCount[slot1] != 4 {
|
||||
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Remove seq 2, 3
|
||||
for _, seq := range []int{2, 3} {
|
||||
delete(cache.slotForSeq, seq)
|
||||
cache.refCount[slot1]--
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Slot should still be allocated (not in free list)
|
||||
found := false
|
||||
for _, s := range cache.freeSlots {
|
||||
if s == slot1 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
t.Error("slot1 should not be in free list yet")
|
||||
}
|
||||
|
||||
// Remove remaining sequences
|
||||
for _, seq := range []int{1, 4} {
|
||||
delete(cache.slotForSeq, seq)
|
||||
cache.refCount[slot1]--
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 0 {
|
||||
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_ChainedSharing(t *testing.T) {
|
||||
cache := createSlotOnlyCache(8)
|
||||
|
||||
// Create seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Share 1 -> 2
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Share 2 -> 3 (should still share slot1)
|
||||
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// All should share slot1
|
||||
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
|
||||
t.Error("all sequences should share slot1")
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 3 {
|
||||
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_CacheParameters(t *testing.T) {
|
||||
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
|
||||
|
||||
if cache.hiddenSize != 512 {
|
||||
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
|
||||
}
|
||||
if cache.dConv != 5 {
|
||||
t.Errorf("expected dConv 5, got %d", cache.dConv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_NumSeqs(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially no sequences
|
||||
if cache.numSeqs() != 0 {
|
||||
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
|
||||
}
|
||||
|
||||
// Manually set up current batch state
|
||||
cache.curSeqs = []int{1, 2, 3}
|
||||
|
||||
if cache.numSeqs() != 3 {
|
||||
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SeqTokens(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially 0
|
||||
if cache.seqTokens() != 0 {
|
||||
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
|
||||
}
|
||||
|
||||
// Manually set up current batch state
|
||||
cache.curSeqTokens = 16
|
||||
|
||||
if cache.seqTokens() != 16 {
|
||||
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Seqs returns a clone of curSeqs
|
||||
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
cache.curSeqs = []int{1, 2, 3}
|
||||
|
||||
seqs := cache.Seqs()
|
||||
|
||||
// Modify returned slice
|
||||
seqs[0] = 999
|
||||
|
||||
// Original should be unchanged
|
||||
if cache.curSeqs[0] != 1 {
|
||||
t.Error("Seqs should return a clone, not the original slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Initially not supported (no batch set up)
|
||||
if cache.IsSupportedForBatch() {
|
||||
t.Error("expected IsSupportedForBatch to be false initially")
|
||||
}
|
||||
|
||||
// Set up a valid batch
|
||||
cache.curSeqTokens = 1
|
||||
cache.curSeqs = []int{1}
|
||||
|
||||
if !cache.IsSupportedForBatch() {
|
||||
t.Error("expected IsSupportedForBatch to be true with valid batch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// zeroConvSlots should handle empty slots without panicking
|
||||
cache.zeroConvSlots(nil, nil)
|
||||
cache.zeroConvSlots(nil, []int{})
|
||||
|
||||
// zeroConvSlots should handle empty convStates without panicking
|
||||
cache.zeroConvSlots(nil, []int{0, 1, 2})
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free the slot (simulating sequence removal)
|
||||
cache.refCount[slot1]--
|
||||
cache.freeSlot(slot1)
|
||||
delete(cache.slotForSeq, 1)
|
||||
|
||||
// Verify slot is in free list
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Allocate for new seq 2 - should get recycled slot
|
||||
slot2, _ := cache.allocSlot()
|
||||
if slot2 != slot1 {
|
||||
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
|
||||
}
|
||||
|
||||
// This recycled slot would need zeroing in the real implementation
|
||||
// The actual zeroing is tested via integration tests since it requires ML context
|
||||
}
|
||||
|
||||
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Simulate the slot allocation flow from StartForward
|
||||
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
|
||||
|
||||
newSlots := []int{}
|
||||
|
||||
// Seq 1 doesn't have a slot - allocate and track
|
||||
seq := 1
|
||||
if _, ok := cache.slotForSeq[seq]; !ok {
|
||||
slot, err := cache.allocSlot()
|
||||
if err != nil {
|
||||
t.Fatalf("allocSlot failed: %v", err)
|
||||
}
|
||||
cache.slotForSeq[seq] = slot
|
||||
cache.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
|
||||
// Verify newSlots contains the allocated slot
|
||||
if len(newSlots) != 1 {
|
||||
t.Errorf("expected 1 new slot, got %d", len(newSlots))
|
||||
}
|
||||
|
||||
// Seq 1 already has a slot - should NOT be tracked as new
|
||||
newSlots2 := []int{}
|
||||
if _, ok := cache.slotForSeq[seq]; !ok {
|
||||
slot, _ := cache.allocSlot()
|
||||
cache.slotForSeq[seq] = slot
|
||||
cache.refCount[slot] = 1
|
||||
newSlots2 = append(newSlots2, slot)
|
||||
}
|
||||
|
||||
// Verify no new slots for existing sequence
|
||||
if len(newSlots2) != 0 {
|
||||
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
|
||||
}
|
||||
}
|
||||
253
model/models/lfm2/model.go
Normal file
253
model/models/lfm2/model.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize int
|
||||
headDim, ropeDim int
|
||||
|
||||
eps, ropeBase, ropeScale float32
|
||||
|
||||
ropeType string
|
||||
originalContextLength int
|
||||
|
||||
// per-layer head counts (LFM2 alternates attention and recurrent layers)
|
||||
numHeadsByLayer []int
|
||||
numKVHeadsByLayer []int
|
||||
}
|
||||
|
||||
func (o Options) headDimValue() int {
|
||||
// Head dim is shared across layers; fall back to first attention layer head count.
|
||||
for _, h := range o.numHeadsByLayer {
|
||||
if h > 0 {
|
||||
return cmp.Or(o.headDim, o.hiddenSize/h)
|
||||
}
|
||||
}
|
||||
return cmp.Or(o.headDim, o.hiddenSize)
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
opts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
)
|
||||
}
|
||||
|
||||
headCount := 1
|
||||
for _, h := range o.numHeadsByLayer {
|
||||
if h > 0 {
|
||||
headCount = h
|
||||
break
|
||||
}
|
||||
}
|
||||
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if c.Uint("expert_count") > 0 {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
if c.String("tokenizer.ggml.model") != "gpt2" {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
var pretokenizers []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
case "default":
|
||||
// use default BPE pretokenizer
|
||||
default:
|
||||
// llama-bpe style (default for LFM2)
|
||||
pretokenizers = []string{
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
}
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
},
|
||||
}
|
||||
|
||||
type headCounts interface {
|
||||
HeadCount() []uint64
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
hc, ok := c.(headCounts)
|
||||
if !ok {
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
|
||||
headCount := hc.HeadCount()
|
||||
headCountKV := hc.HeadCountKV()
|
||||
|
||||
m.numHeadsByLayer = make([]int, len(m.Layers))
|
||||
m.numKVHeadsByLayer = make([]int, len(m.Layers))
|
||||
for i := range m.Layers {
|
||||
m.numHeadsByLayer[i] = int(headCount[i])
|
||||
m.numKVHeadsByLayer[i] = int(headCountKV[i])
|
||||
|
||||
if m.numKVHeadsByLayer[i] == 0 {
|
||||
m.Layers[i].Operator = &ShortConv{}
|
||||
} else {
|
||||
m.Layers[i].Operator = &Attention{}
|
||||
}
|
||||
}
|
||||
|
||||
lCache := int(c.Uint("shortconv.l_cache"))
|
||||
dConv := max(0, lCache-1)
|
||||
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type Operator interface {
|
||||
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
|
||||
}
|
||||
|
||||
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
headDim := opts.headDimValue()
|
||||
numHeads := opts.numHeadsByLayer[layer]
|
||||
numKVHeads := opts.numKVHeadsByLayer[layer]
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, headDim, numHeads, batchSize)
|
||||
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
|
||||
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Operator Operator
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("lfm2", New)
|
||||
}
|
||||
50
model/models/lfm2/shortconv.go
Normal file
50
model/models/lfm2/shortconv.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
type shortConvKernel struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
}
|
||||
|
||||
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
|
||||
// state stored in the HybridCache.
|
||||
type ShortConv struct {
|
||||
Conv *shortConvKernel `gguf:"shortconv.conv"`
|
||||
InProj *nn.Linear `gguf:"shortconv.in_proj"`
|
||||
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
|
||||
}
|
||||
|
||||
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
|
||||
nSeqs := cache.numSeqs()
|
||||
seqTokens := cache.seqTokens()
|
||||
hiddenSize := hiddenStates.Dim(0)
|
||||
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
|
||||
panic("lfm2: unsupported batch layout for shortconv")
|
||||
}
|
||||
|
||||
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
|
||||
|
||||
elementSize := bcx.Stride(0)
|
||||
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
|
||||
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
|
||||
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
|
||||
|
||||
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
state, err := cache.ConvState(ctx, layer)
|
||||
if err != nil {
|
||||
panic("lfm2: failed to get conv state: " + err.Error())
|
||||
}
|
||||
sx := state.Concat(ctx, bx, 0)
|
||||
|
||||
convOut := sx.SSMConv(ctx, sc.Conv.Weight)
|
||||
y := c.Mul(ctx, convOut)
|
||||
|
||||
dConv := sx.Dim(0) - seqTokens
|
||||
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
|
||||
|
||||
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
_ "github.com/ollama/ollama/model/models/lfm2"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
|
||||
498
model/parsers/lfm2.go
Normal file
498
model/parsers/lfm2.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// LFM2ParserState identifies which kind of output the LFM2 parser is
// currently collecting.
type LFM2ParserState int

// Parser states. The zero value is LFM2CollectingThinking.
const (
	// LFM2CollectingThinking: input is treated as thinking text until the
	// closing think tag is seen.
	LFM2CollectingThinking LFM2ParserState = iota
	// LFM2CollectingContent: input is treated as regular content.
	LFM2CollectingContent
	// LFM2CollectingToolCalls: presumably entered while inside a tool-call
	// section; the handling code is outside this view.
	LFM2CollectingToolCalls
)
|
||||
|
||||
// Literal markers the LFM2 parser looks for in model output.
const (
	// Delimiters of the thinking section.
	lfm2ThinkingOpenTag  = "<think>"
	lfm2ThinkingCloseTag = "</think>"

	// Delimiters of a tool-call section.
	lfm2ToolCallStartTag = "<|tool_call_start|>"
	lfm2ToolCallEndTag   = "<|tool_call_end|>"
)
|
||||
|
||||
type LFM2Parser struct {
|
||||
state LFM2ParserState
|
||||
buffer strings.Builder
|
||||
hasThinkingSupport bool
|
||||
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
|
||||
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
// Check both model capability AND request preference
|
||||
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
|
||||
|
||||
if !thinkingEnabled {
|
||||
p.state = LFM2CollectingContent
|
||||
return
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = LFM2CollectingContent
|
||||
return
|
||||
}
|
||||
|
||||
p.state = LFM2CollectingThinking
|
||||
p.needsThinkingLeadingTrim = true
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.setInitialState(lastMessage, thinkValue)
|
||||
return tools
|
||||
}
|
||||
|
||||
// lfm2Event is the closed set of events produced while parsing. The
// unexported marker method restricts implementations to this package.
type lfm2Event interface {
	isLFM2Event()
}
|
||||
|
||||
type lfm2EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type lfm2EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type lfm2EventToolCall struct {
|
||||
toolCall api.ToolCall
|
||||
}
|
||||
|
||||
func (lfm2EventThinkingContent) isLFM2Event() {}
|
||||
func (lfm2EventContent) isLFM2Event() {}
|
||||
func (lfm2EventToolCall) isLFM2Event() {}
|
||||
|
||||
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case lfm2EventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case lfm2EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case lfm2EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) parseEvents() []lfm2Event {
|
||||
var all []lfm2Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []lfm2Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
|
||||
var events []lfm2Event
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case LFM2CollectingThinking:
|
||||
// Strip opening <think> tag if present
|
||||
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
|
||||
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
|
||||
p.needsThinkingLeadingTrim = true
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
}
|
||||
|
||||
// Trim leading whitespace after <think> tag (may span multiple chunks)
|
||||
if p.needsThinkingLeadingTrim {
|
||||
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
|
||||
bufStr = trimmed
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
}
|
||||
// Clear flag once we have non-whitespace content or buffer is empty
|
||||
if len(bufStr) > 0 {
|
||||
p.needsThinkingLeadingTrim = false
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
|
||||
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
|
||||
thinking := split[0]
|
||||
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
|
||||
|
||||
remaining := split[1]
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = LFM2CollectingContent
|
||||
p.needsThinkingLeadingTrim = false
|
||||
// Set flag to trim any additional whitespace that may arrive in later chunks
|
||||
p.needsContentLeadingTrim = len(remaining) == 0
|
||||
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, lfm2EventThinkingContent{content: thinking})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, lfm2EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else { // otherwise its thinking content
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, lfm2EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case LFM2CollectingContent:
|
||||
// Trim leading whitespace after </think> tag (may span multiple chunks)
|
||||
if p.needsContentLeadingTrim {
|
||||
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
|
||||
bufStr = trimmed
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
}
|
||||
// Clear flag once we have non-whitespace content
|
||||
if len(bufStr) > 0 {
|
||||
p.needsContentLeadingTrim = false
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
|
||||
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
|
||||
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = LFM2CollectingToolCalls
|
||||
|
||||
if len(contentBefore) > 0 {
|
||||
events = append(events, lfm2EventContent{content: contentBefore})
|
||||
}
|
||||
return events, true
|
||||
} else { // otherwise its content
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, lfm2EventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case LFM2CollectingToolCalls:
|
||||
// Look for complete tool call JSON between tags
|
||||
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
|
||||
toolCallContent := bufStr[:idx]
|
||||
|
||||
if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
|
||||
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
|
||||
|
||||
// Check if there's another tool call
|
||||
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
|
||||
remaining = remaining[len(lfm2ToolCallStartTag):]
|
||||
} else {
|
||||
// No more tool calls, go back to content
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
p.state = LFM2CollectingContent
|
||||
}
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
events = append(events, lfm2EventToolCall{toolCall: tc})
|
||||
}
|
||||
return events, true
|
||||
} else if err != nil {
|
||||
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
|
||||
}
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseToolCallsContent parses one or more tool calls from content
|
||||
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
|
||||
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Try JSON format first: {"name": "func", "arguments": {...}}
|
||||
var parsed struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
|
||||
var args api.ToolCallFunctionArguments
|
||||
if len(parsed.Arguments) > 0 {
|
||||
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
args = api.NewToolCallFunctionArguments()
|
||||
}
|
||||
|
||||
return []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: parsed.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
|
||||
return p.parsePythonStyleToolCalls(content)
|
||||
}
|
||||
|
||||
// parsePythonStyleToolCalls parses one or more Python-style tool calls
|
||||
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
|
||||
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Strip outer brackets if present: [func(...)] -> func(...)
|
||||
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
|
||||
content = content[1 : len(content)-1]
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
// Parse multiple function calls separated by commas at the top level
|
||||
for len(content) > 0 {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip leading comma from previous iteration
|
||||
if strings.HasPrefix(content, ",") {
|
||||
content = strings.TrimSpace(content[1:])
|
||||
if content == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Find function name
|
||||
parenIdx := strings.Index(content, "(")
|
||||
if parenIdx == -1 {
|
||||
return nil, errors.New("invalid tool call: no opening parenthesis")
|
||||
}
|
||||
|
||||
funcName := strings.TrimSpace(content[:parenIdx])
|
||||
if funcName == "" {
|
||||
return nil, errors.New("invalid tool call: empty function name")
|
||||
}
|
||||
|
||||
// Find matching closing parenthesis
|
||||
closeIdx := findMatchingParen(content, parenIdx)
|
||||
if closeIdx == -1 {
|
||||
return nil, errors.New("invalid tool call: no matching closing parenthesis")
|
||||
}
|
||||
|
||||
argsStr := content[parenIdx+1 : closeIdx]
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
|
||||
if argsStr != "" {
|
||||
if err := parsePythonArgs(argsStr, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: funcName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
|
||||
// Move past this function call
|
||||
content = content[closeIdx+1:]
|
||||
}
|
||||
|
||||
if len(toolCalls) == 0 {
|
||||
return nil, errors.New("no tool calls found")
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// findMatchingParen returns the index of the ')' that closes the '(' located
// at openIdx, or -1 when there is none. Nested parentheses are counted, and
// parentheses inside single- or double-quoted strings are ignored; a
// backslash inside a quoted string escapes the following byte.
func findMatchingParen(s string, openIdx int) int {
	nesting := 1
	for pos := openIdx + 1; pos < len(s); pos++ {
		switch c := s[pos]; c {
		case '(':
			nesting++
		case ')':
			nesting--
			if nesting == 0 {
				return pos
			}
		case '\'', '"':
			// Advance pos to the matching closing quote (or the end of s); the
			// loop's own increment then steps past it.
			pos++
			for pos < len(s) && s[pos] != c {
				if s[pos] == '\\' && pos+1 < len(s) {
					pos++ // skip escaped char
				}
				pos++
			}
		}
	}
	return -1
}
|
||||
|
||||
// parseToolCallContent parses a single tool call (for backward compatibility with tests)
|
||||
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
|
||||
calls, err := p.parseToolCallsContent(content)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
if len(calls) == 0 {
|
||||
return api.ToolCall{}, errors.New("no tool call found")
|
||||
}
|
||||
return calls[0], nil
|
||||
}
|
||||
|
||||
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
|
||||
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
|
||||
// Simple state machine to parse key='value' pairs
|
||||
// Handles: command='ls', flag="-la", count=42, enabled=true
|
||||
var key string
|
||||
i := 0
|
||||
|
||||
for i < len(argsStr) {
|
||||
// Skip whitespace
|
||||
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
|
||||
i++
|
||||
}
|
||||
if i >= len(argsStr) {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse key
|
||||
keyStart := i
|
||||
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
|
||||
i++
|
||||
}
|
||||
if i >= len(argsStr) || argsStr[i] != '=' {
|
||||
return errors.New("invalid argument: expected '='")
|
||||
}
|
||||
key = strings.TrimSpace(argsStr[keyStart:i])
|
||||
i++ // skip '='
|
||||
|
||||
// Skip whitespace after =
|
||||
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
|
||||
// Parse value
|
||||
var value string
|
||||
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
|
||||
// Quoted string
|
||||
quote := argsStr[i]
|
||||
i++
|
||||
valueStart := i
|
||||
for i < len(argsStr) && argsStr[i] != quote {
|
||||
if argsStr[i] == '\\' && i+1 < len(argsStr) {
|
||||
i += 2 // skip escaped char
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
value = argsStr[valueStart:i]
|
||||
if i < len(argsStr) {
|
||||
i++ // skip closing quote
|
||||
}
|
||||
args.Set(key, value)
|
||||
} else {
|
||||
// Unquoted value (number, bool, etc)
|
||||
valueStart := i
|
||||
for i < len(argsStr) && argsStr[i] != ',' {
|
||||
i++
|
||||
}
|
||||
value = strings.TrimSpace(argsStr[valueStart:i])
|
||||
|
||||
// Try to parse as number or bool
|
||||
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
args.Set(key, v)
|
||||
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
args.Set(key, v)
|
||||
} else if value == "true" {
|
||||
args.Set(key, true)
|
||||
} else if value == "false" {
|
||||
args.Set(key, false)
|
||||
} else {
|
||||
args.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip comma and whitespace
|
||||
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
1088
model/parsers/lfm2_test.go
Normal file
1088
model/parsers/lfm2_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -70,6 +70,10 @@ func ParserForName(name string) Parser {
|
||||
return &FunctionGemmaParser{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Parser{}
|
||||
case "lfm2":
|
||||
return &LFM2Parser{hasThinkingSupport: false}
|
||||
case "lfm2-thinking":
|
||||
return &LFM2Parser{hasThinkingSupport: true}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
144
model/renderers/lfm2.go
Normal file
144
model/renderers/lfm2.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type LFM2Renderer struct {
|
||||
IsThinking bool
|
||||
}
|
||||
|
||||
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
|
||||
|
||||
// Extract first system message if present (to combine with tools)
|
||||
var firstSystemContent string
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
firstSystemContent = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
// Append tools to first system content
|
||||
if len(tools) > 0 {
|
||||
if firstSystemContent != "" {
|
||||
firstSystemContent += "\n"
|
||||
}
|
||||
firstSystemContent += "List of tools: ["
|
||||
for i, tool := range tools {
|
||||
toolJSON, err := json.Marshal(tool)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
firstSystemContent += string(toolJSON)
|
||||
if i < len(tools)-1 {
|
||||
firstSystemContent += ", "
|
||||
}
|
||||
}
|
||||
firstSystemContent += "]"
|
||||
}
|
||||
|
||||
// Output first system block if it has content
|
||||
if firstSystemContent != "" {
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(firstSystemContent)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
// Find the index of the last assistant message for thinking stripping
|
||||
lastAssistantIndex := -1
|
||||
for i := len(messages) - 1; i >= startIdx; i-- {
|
||||
if messages[i].Role == "assistant" {
|
||||
lastAssistantIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Track whether we need to add generation prompt
|
||||
needsGenerationPrompt := len(messages) > 0
|
||||
|
||||
for i := startIdx; i < len(messages); i++ {
|
||||
message := messages[i]
|
||||
switch message.Role {
|
||||
case "system":
|
||||
// Additional system messages (after the first) are rendered normally
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
// Check if this is the last assistant message
|
||||
isLastAssistant := i == lastAssistantIndex
|
||||
|
||||
// Process content (may need thinking stripped)
|
||||
content := message.Content
|
||||
|
||||
// Handle thinking tags in assistant content
|
||||
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
|
||||
if strings.Contains(content, "</think>") {
|
||||
parts := strings.SplitN(content, "</think>", 2)
|
||||
if len(parts) > 1 {
|
||||
if !isLastAssistant && !keepPastThinking {
|
||||
// Strip thinking entirely for past assistant messages
|
||||
content = strings.TrimSpace(parts[1])
|
||||
} else {
|
||||
// Preserve thinking but trim whitespace after </think>
|
||||
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
// Assistant with tool calls - write content first (if any after stripping)
|
||||
if content != "" {
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<|tool_call_start|>")
|
||||
toolCallJSON := map[string]any{
|
||||
"name": toolCall.Function.Name,
|
||||
"arguments": toolCall.Function.Arguments,
|
||||
}
|
||||
callJSON, _ := json.Marshal(toolCallJSON)
|
||||
sb.WriteString(string(callJSON))
|
||||
sb.WriteString("<|tool_call_end|>")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
|
||||
|
||||
case "tool":
|
||||
// Tool responses are rendered as plain messages per the chat template
|
||||
sb.WriteString("<|im_start|>tool\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
needsGenerationPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
// Note: Model is a "thinking-only" model - it will output <think> itself
|
||||
// We don't add <think> tag to the prompt
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
427
model/renderers/lfm2_test.go
Normal file
427
model/renderers/lfm2_test.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestLFM2Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic user message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "basic with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multiple system messages rendered separately",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "First instruction."},
|
||||
{Role: "system", Content: "Second instruction."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
{Role: "assistant", Content: "The answer is 4."},
|
||||
{Role: "user", Content: "Thanks!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "only system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
|
||||
name: "user-assistant-user: last assistant preserves thinking",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reasoning</think>A1"},
|
||||
{Role: "user", Content: "Q2"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// With two assistants, first is stripped (not last), second preserved (is last)
|
||||
name: "multi-turn thinking: first stripped, second preserved",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reason1</think>A1"},
|
||||
{Role: "user", Content: "Q2"},
|
||||
{Role: "assistant", Content: "<think>reason2</think>A2"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// With thinking enabled (keep_past_thinking=true), both preserved
|
||||
name: "multi-turn thinking: both preserved when thinking enabled",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Q1"},
|
||||
{Role: "assistant", Content: "<think>reason1</think>A1"},
|
||||
{Role: "user", Content: "Q2"},
|
||||
{Role: "assistant", Content: "<think>reason2</think>A2"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with content and tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tool response",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{Role: "assistant", Content: "Let me check."},
|
||||
{Role: "tool", Content: "22C, Sunny"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Get weather for Paris and London"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tools definitions with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "tools definitions without system message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
|
||||
},
|
||||
{
|
||||
name: "multiple tools without system message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get time",
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "user-tool sequence",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Check weather"},
|
||||
{Role: "tool", Content: "22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "full tool call cycle",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Check weather"},
|
||||
{Role: "assistant", Content: "Let me check"},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "It's 22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "unicode content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "你好世界! مرحبا 🌍"},
|
||||
{Role: "assistant", Content: "Hello! 👋"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "newlines in content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
|
||||
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "empty assistant content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "user", Content: "OK"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Generation prompt does NOT include <think> - model outputs it
|
||||
name: "generation prompt has no think tag",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Think hard"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Interleaved: thinking before tool call - last assistant preserves thinking
|
||||
name: "thinking before tool call (last assistant)",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>I need to check the weather</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tool calls - first has thinking stripped
|
||||
name: "two assistants with tools: first thinking stripped",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tools - both preserved when thinking enabled
|
||||
name: "two assistants with tools: both preserved when thinking enabled",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Content before thinking before tool call
|
||||
name: "content then thinking then tool call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.<think>Using weather API</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
}
|
||||
|
||||
renderer := &LFM2Renderer{IsThinking: true}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
|
||||
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -82,6 +82,10 @@ func rendererForName(name string) Renderer {
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false}
|
||||
case "lfm2-thinking":
|
||||
return &LFM2Renderer{IsThinking: true}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
oldManifest, _ := ParseNamedManifest(name)
|
||||
oldManifest, _ := manifest.ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
mf, mErr := manifest.ParseNamedManifest(fromName)
|
||||
if mErr == nil && mf.Config.Digest != "" {
|
||||
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
|
||||
if pErr == nil {
|
||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||
var baseConfig model.ConfigV2
|
||||
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
|
||||
return "gguf"
|
||||
} else {
|
||||
// try to see if we can find a gguf file even without the file extension
|
||||
blobPath, err := GetBlobsPath(files[fn])
|
||||
blobPath, err := manifest.BlobsPath(files[fn])
|
||||
if err != nil {
|
||||
slog.Error("error getting blobs path", "file", fn)
|
||||
return ""
|
||||
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(t, mediaType)
|
||||
layer, err := manifest.NewLayer(t, mediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
var layers []manifest.Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if layer.status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.status})
|
||||
if layer.Status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.Status})
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := WriteManifest(name, *configLayer, layers); err != nil {
|
||||
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
blob, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
|
||||
}
|
||||
temp.Seek(0, io.SeekStart)
|
||||
fn(api.ProgressResponse{Status: "verifying conversion"})
|
||||
newLayer, err := NewLayer(temp, layer.MediaType)
|
||||
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
var layers []*layerGGML
|
||||
|
||||
fn(api.ProgressResponse{Status: "parsing GGUF"})
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
blobPath, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
return slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
|
||||
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
})
|
||||
}
|
||||
|
||||
func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
}
|
||||
|
||||
blob := strings.NewReader(t)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.system")
|
||||
if s != "" {
|
||||
blob := strings.NewReader(s)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
|
||||
blob := strings.NewReader(l)
|
||||
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
|
||||
if p == nil {
|
||||
p = make(map[string]any)
|
||||
}
|
||||
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
digestPath, err := GetBlobsPath(layer.Digest)
|
||||
digestPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
|
||||
// this leaves the old messages intact if no new messages were specified
|
||||
// which may not be the correct behaviour
|
||||
if len(m) == 0 {
|
||||
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
|
||||
digests := make([]string, len(layers))
|
||||
for i, layer := range layers {
|
||||
digests[i] = layer.Digest
|
||||
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
|
||||
if err := json.NewEncoder(&b).Encode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestConvertFromSafetensors(t *testing.T) {
|
||||
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
|
||||
|
||||
// Helper function to create a new layer and return its digest
|
||||
makeTemp := func(content string) string {
|
||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create layer: %v", err)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
}
|
||||
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
n model.Name
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
@@ -465,10 +467,10 @@ type downloadOpts struct {
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
|
||||
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
|
||||
}
|
||||
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
fp, err := manifest.BlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
download := data.(*blobDownload)
|
||||
if !ok {
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
||||
requestURL := opts.n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return false, err
|
||||
|
||||
205
server/images.go
205
server/images.go
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -274,44 +274,22 @@ func (m *Model) String() string {
|
||||
return modelfile.String()
|
||||
}
|
||||
|
||||
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||
fp, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sha256sum := sha256.New()
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
n := model.ParseName(name)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model := &Model{
|
||||
Name: mp.GetFullTagname(),
|
||||
ShortName: mp.GetShortTagname(),
|
||||
Digest: digest,
|
||||
m := &Model{
|
||||
Name: n.String(),
|
||||
ShortName: n.DisplayShortest(),
|
||||
Digest: mf.Digest(),
|
||||
Template: template.DefaultTemplate,
|
||||
}
|
||||
|
||||
if manifest.Config.Digest != "" {
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
if mf.Config.Digest != "" {
|
||||
filename, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -322,29 +300,29 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
||||
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
filename, err := GetBlobsPath(layer.Digest)
|
||||
for _, layer := range mf.Layers {
|
||||
filename, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
model.ModelPath = filename
|
||||
model.ParentModel = layer.From
|
||||
m.ModelPath = filename
|
||||
m.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||
case "application/vnd.ollama.image.adapter":
|
||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||
m.AdapterPaths = append(m.AdapterPaths, filename)
|
||||
case "application/vnd.ollama.image.projector":
|
||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||
m.ProjectorPaths = append(m.ProjectorPaths, filename)
|
||||
case "application/vnd.ollama.image.prompt",
|
||||
"application/vnd.ollama.image.template":
|
||||
bts, err := os.ReadFile(filename)
|
||||
@@ -352,7 +330,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.Template, err = template.Parse(string(bts))
|
||||
m.Template, err = template.Parse(string(bts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -362,7 +340,7 @@ func GetModel(name string) (*Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
model.System = string(bts)
|
||||
m.System = string(bts)
|
||||
case "application/vnd.ollama.image.params":
|
||||
params, err := os.Open(filename)
|
||||
if err != nil {
|
||||
@@ -371,7 +349,7 @@ func GetModel(name string) (*Model, error) {
|
||||
defer params.Close()
|
||||
|
||||
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.messages":
|
||||
@@ -381,7 +359,7 @@ func GetModel(name string) (*Model, error) {
|
||||
}
|
||||
defer msgs.Close()
|
||||
|
||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
||||
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.license":
|
||||
@@ -389,11 +367,11 @@ func GetModel(name string) (*Model, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.License = append(model.License, string(bts))
|
||||
m.License = append(m.License, string(bts))
|
||||
}
|
||||
}
|
||||
|
||||
return model, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dst model.Name) error {
|
||||
@@ -408,7 +386,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
manifests, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -437,7 +415,7 @@ func CopyModel(src, dst model.Name) error {
|
||||
|
||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
|
||||
manifests, err := Manifests(true)
|
||||
manifests, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -452,7 +430,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
// only delete the files which are still in the deleteMap
|
||||
for k := range deleteMap {
|
||||
fp, err := GetBlobsPath(k)
|
||||
fp, err := manifest.BlobsPath(k)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
|
||||
continue
|
||||
@@ -468,7 +446,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
|
||||
func PruneLayers() error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
p, err := GetBlobsPath("")
|
||||
p, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,9 +461,9 @@ func PruneLayers() error {
|
||||
name := blob.Name()
|
||||
name = strings.ReplaceAll(name, "-", ":")
|
||||
|
||||
_, err := GetBlobsPath(name)
|
||||
_, err := manifest.BlobsPath(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidDigestFormat) {
|
||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||
// remove invalid blobs (e.g. partial downloads)
|
||||
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
|
||||
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
|
||||
@@ -510,63 +488,30 @@ func PruneLayers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PruneDirectory removes empty directories beneath path, depth-first, and
// finally removes path itself if nothing is left inside it.
//
// Regular files and symbolic links are never removed or followed: the entry
// is inspected with Lstat, and anything that is not a plain directory is left
// untouched. A directory that still has entries after its children have been
// pruned is kept. The first error encountered (stat, read, or remove) aborts
// the walk and is returned as-is.
func PruneDirectory(path string) error {
	fi, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// Nothing to do for files or symlinks.
	if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for _, child := range children {
		if err := PruneDirectory(filepath.Join(path, child.Name())); err != nil {
			return err
		}
	}

	// Read the directory again: the recursion above may have removed
	// subdirectories, leaving this one empty.
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	if len(remaining) != 0 {
		return nil
	}

	return os.Remove(path)
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
mf, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
|
||||
return err
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
manifestPath, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -574,7 +519,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -582,17 +527,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -611,44 +556,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
n := model.ParseName(name)
|
||||
|
||||
// build deleteMap to prune unused layers
|
||||
deleteMap := make(map[string]struct{})
|
||||
manifest, _, err := GetManifest(mp)
|
||||
existingMf, err := manifest.ParseNamedManifest(n)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
|
||||
} else {
|
||||
for _, l := range manifest.Layers {
|
||||
for _, l := range existingMf.Layers {
|
||||
deleteMap[l.Digest] = struct{}{}
|
||||
}
|
||||
if manifest.Config.Digest != "" {
|
||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||
if existingMf.Config.Digest != "" {
|
||||
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
mf, err := pullModelManifest(ctx, n, regOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, manifest.Config)
|
||||
var layers []manifest.Layer
|
||||
layers = append(layers, mf.Layers...)
|
||||
if mf.Config.Digest != "" {
|
||||
layers = append(layers, mf.Config)
|
||||
}
|
||||
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
|
||||
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "success"})
|
||||
@@ -658,7 +603,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
n: n,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
fn: fn,
|
||||
@@ -677,7 +622,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
fp, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -692,16 +637,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
for _, layer := range layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
delete(deleteMap, mf.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -728,9 +673,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
// hasTensorLayers checks if any layer has tensor media type.
|
||||
func hasTensorLayers(layers []Layer) bool {
|
||||
func hasTensorLayers(layers []manifest.Layer) bool {
|
||||
for _, layer := range layers {
|
||||
if layer.MediaType == MediaTypeImageTensor {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -738,7 +683,7 @@ func hasTensorLayers(layers []Layer) bool {
|
||||
}
|
||||
|
||||
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
|
||||
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -747,12 +692,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
destDir, err := GetBlobsPath("")
|
||||
destDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -784,7 +729,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Blobs: blobs,
|
||||
BaseURL: baseURL,
|
||||
DestDir: destDir,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
Progress: progress,
|
||||
Token: regOpts.Token,
|
||||
GetToken: getToken,
|
||||
@@ -795,12 +740,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
manifestJSON, err := json.Marshal(mf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath()
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -812,7 +757,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
|
||||
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
|
||||
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
blobs := make([]transfer.Blob, len(layers))
|
||||
for i, layer := range layers {
|
||||
blobs[i] = transfer.Blob{
|
||||
@@ -822,12 +767,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
}
|
||||
}
|
||||
|
||||
srcDir, err := GetBlobsPath("")
|
||||
srcDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
base := mp.BaseURL()
|
||||
base := n.BaseURL()
|
||||
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
base.Scheme = "http"
|
||||
}
|
||||
@@ -864,13 +809,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
GetToken: getToken,
|
||||
Logger: slog.Default(),
|
||||
Manifest: manifestJSON,
|
||||
ManifestRef: mp.Tag,
|
||||
Repository: mp.GetNamespaceRepository(),
|
||||
ManifestRef: n.Tag,
|
||||
Repository: n.DisplayNamespaceModel(),
|
||||
})
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
|
||||
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
@@ -880,7 +825,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var m Manifest
|
||||
var m manifest.Manifest
|
||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1042,7 +987,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
|
||||
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
|
||||
|
||||
func verifyBlob(digest string) error {
|
||||
fp, err := GetBlobsPath(digest)
|
||||
fp, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -20,19 +21,19 @@ import (
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
*ggml.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
m, err := ParseNamedManifest(name)
|
||||
m, err := manifest.ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err = ParseNamedManifest(name)
|
||||
m, err = manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
case "application/vnd.ollama.image.model",
|
||||
"application/vnd.ollama.image.projector",
|
||||
"application/vnd.ollama.image.adapter":
|
||||
blobpath, err := GetBlobsPath(layer.Digest)
|
||||
blobpath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
if t, err := template.Named(s); err != nil {
|
||||
slog.Debug("template detection", "error", err, "template", s)
|
||||
} else {
|
||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
|
||||
if t.Parameters != nil {
|
||||
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// ModelPath holds the individual parts of a model reference: the protocol
// scheme, the registry host, the namespace, the repository and the tag.
type ModelPath struct {
	// ProtocolScheme is the URL scheme used when talking to the registry.
	ProtocolScheme string
	// Registry is the registry host.
	Registry string
	// Namespace is the namespace the repository lives in.
	Namespace string
	// Repository is the model repository name.
	Repository string
	// Tag is the model tag.
	Tag string
}
|
||||
|
||||
// Default values that ParseModelPath uses for any part the name leaves out.
const (
	// DefaultRegistry is the registry host assumed when none is given.
	DefaultRegistry = "registry.ollama.ai"
	// DefaultNamespace is the namespace assumed when none is given.
	DefaultNamespace = "library"
	// DefaultTag is the tag assumed when none is given.
	DefaultTag = "latest"
	// DefaultProtocolScheme is the scheme assumed when none is given.
	DefaultProtocolScheme = "https"
)
|
||||
|
||||
// Sentinel errors for model path and digest handling. Compare against them
// with errors.Is.
var (
	// ErrInvalidImageFormat reports an invalid image format.
	ErrInvalidImageFormat = errors.New("invalid image format")
	// ErrInvalidDigestFormat reports a digest that is not a sha256 digest.
	ErrInvalidDigestFormat = errors.New("invalid digest format")
	// ErrInvalidProtocol reports an invalid protocol scheme.
	ErrInvalidProtocol = errors.New("invalid protocol scheme")
	// ErrInsecureProtocol reports use of the insecure http protocol.
	ErrInsecureProtocol = errors.New("insecure protocol http")
	// ErrModelPathInvalid reports an invalid model path.
	ErrModelPathInvalid = errors.New("invalid model path")
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
|
||||
before, after, found := strings.Cut(name, "://")
|
||||
if found {
|
||||
mp.ProtocolScheme = before
|
||||
name = after
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
}
|
||||
|
||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetFullTagname() string {
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetShortTagname() string {
|
||||
if mp.Registry == DefaultRegistry {
|
||||
if mp.Namespace == DefaultNamespace {
|
||||
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
|
||||
}
|
||||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: mp.ProtocolScheme,
|
||||
Host: mp.Registry,
|
||||
}
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func GetBlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(envconfig.Models(), "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
if digest == "" {
|
||||
dirPath = path
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBlobsPath(t *testing.T) {
|
||||
// GetBlobsPath expects an actual directory to exist
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expected string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"empty digest",
|
||||
"",
|
||||
filepath.Join(tempDir, "blobs"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with colon",
|
||||
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with dash",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"digest too short",
|
||||
"sha256-45640291",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest too long",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest invalid chars",
|
||||
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
got, err := GetBlobsPath(tc.digest)
|
||||
|
||||
require.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.Equal(t, tc.expected, got, tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg string
|
||||
want ModelPath
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
"https://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"full path http",
|
||||
"http://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
"example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
"ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
"repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
"repo",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ParseModelPath(tc.arg)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
name := t.Name
|
||||
quantize := strings.HasSuffix(name, "weight")
|
||||
|
||||
// don't quantize vision stuff
|
||||
quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
|
||||
// don't quantize vision encoder tensors (named with "v." prefix)
|
||||
quantize = quantize && !strings.HasPrefix(name, "v.")
|
||||
quantize = quantize && !strings.Contains(name, "mm.")
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
@@ -219,6 +219,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
|
||||
|
||||
// do not quantize LFM2's shortconv kernel weights
|
||||
quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
|
||||
|
||||
// do not quantize RWKV's time_mix_first tensors
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
|
||||
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
// is.
|
||||
func getExistingName(n model.Name) (model.Name, error) {
|
||||
var zero model.Name
|
||||
existing, err := Manifests(true)
|
||||
existing, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
m, err := manifest.ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, ErrModelPathInvalid
|
||||
return nil, model.Unqualified(name)
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(name)
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1147,8 +1148,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
ModifiedAt: mf.FileInfo().ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
|
||||
// default we return an empty map.
|
||||
ModelInfo: make(map[string]any),
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
@@ -1211,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1220,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
@@ -1282,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
||||
}
|
||||
|
||||
func (s *Server) ListHandler(c *gin.Context) {
|
||||
ms, err := Manifests(true)
|
||||
ms, err := manifest.Manifests(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1313,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
RemoteModel: cf.RemoteModel,
|
||||
RemoteHost: cf.RemoteHost,
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Digest: m.Digest(),
|
||||
ModifiedAt: m.FileInfo().ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
@@ -1373,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1389,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
|
||||
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
|
||||
p, err := GetBlobsPath(ib)
|
||||
p, err := manifest.BlobsPath(ib)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1407,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
path, err := manifest.BlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1425,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
layer, err := manifest.NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1625,7 +1629,7 @@ func Serve(ln net.Listener) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("server config", "env", envconfig.Values())
|
||||
|
||||
blobsDir, err := GetBlobsPath("")
|
||||
blobsDir, err := manifest.BlobsPath("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1634,7 +1638,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
if !envconfig.NoPrune() {
|
||||
if _, err := Manifests(false); err != nil {
|
||||
if _, err := manifest.Manifests(false); err != nil {
|
||||
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
|
||||
} else {
|
||||
// clean up unused layers and manifests
|
||||
@@ -1642,12 +1646,12 @@ func Serve(ln net.Listener) error {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := GetManifestPath()
|
||||
manifestsPath, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := PruneDirectory(manifestsPath); err != nil {
|
||||
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(model.ParseName("child"))
|
||||
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
|
||||
if err != nil {
|
||||
t.Fatalf("parse manifest: %v", err)
|
||||
}
|
||||
if manifest.Config.Digest == "" {
|
||||
if mf.Config.Digest == "" {
|
||||
t.Fatalf("unexpected empty config digest for child manifest")
|
||||
}
|
||||
|
||||
configPath, err := GetBlobsPath(manifest.Config.Digest)
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("config blob path: %v", err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
||||
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,12 +21,14 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
Layer
|
||||
manifest.Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
@@ -51,7 +53,7 @@ const (
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
|
||||
if b.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", b.Digest)
|
||||
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
|
||||
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
p, err := manifest.BlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
|
||||
switch {
|
||||
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
|
||||
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
|
||||
upload := data.(*blobUpload)
|
||||
if !ok {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
requestURL := n.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
|
||||
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
|
||||
blobUploadManager.Delete(layer.Digest)
|
||||
return err
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
|
||||
const MissingPart = "!MISSING!"
|
||||
|
||||
const (
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultHost = "registry.ollama.ai"
|
||||
defaultNamespace = "library"
|
||||
defaultTag = "latest"
|
||||
defaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
// DefaultName returns a name with the default values for the host, namespace,
|
||||
// and tag parts. The model and digest parts are empty.
|
||||
// tag, and protocol scheme parts. The model and digest parts are empty.
|
||||
//
|
||||
// - The default host is ("registry.ollama.ai")
|
||||
// - The default namespace is ("library")
|
||||
// - The default tag is ("latest")
|
||||
// - The default protocol scheme is ("https")
|
||||
func DefaultName() Name {
|
||||
return Name{
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
Host: defaultHost,
|
||||
Namespace: defaultNamespace,
|
||||
Tag: defaultTag,
|
||||
ProtocolScheme: defaultProtocolScheme,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,10 +91,11 @@ func (k partKind) String() string {
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
Host string
|
||||
Namespace string
|
||||
Model string
|
||||
Tag string
|
||||
ProtocolScheme string
|
||||
}
|
||||
|
||||
// ParseName parses and assembles a Name from a name string. The
|
||||
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
|
||||
}
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
if ok {
|
||||
n.ProtocolScheme = scheme
|
||||
} else {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
|
||||
return n
|
||||
}
|
||||
|
||||
// Merge merges the host, namespace, and tag parts of the two names,
|
||||
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
|
||||
// preferring the non-empty parts of a.
|
||||
func Merge(a, b Name) Name {
|
||||
a.Host = cmp.Or(a.Host, b.Host)
|
||||
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
|
||||
a.Tag = cmp.Or(a.Tag, b.Tag)
|
||||
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
|
||||
// BaseURL returns the base URL for the registry.
|
||||
func (n Name) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: n.ProtocolScheme,
|
||||
Host: n.Host,
|
||||
}
|
||||
}
|
||||
|
||||
// DisplayNamespaceModel returns the namespace and model joined by "/".
|
||||
func (n Name) DisplayNamespaceModel() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(n.Namespace)
|
||||
b.WriteByte('/')
|
||||
b.WriteString(n.Model)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isValidLen(kind partKind, s string) bool {
|
||||
switch kind {
|
||||
case kindHost:
|
||||
|
||||
@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
|
||||
{
|
||||
in: "scheme://host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
ProtocolScheme: "scheme",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
layer, err := manifest.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
// Convert LayerInfo to manifest.Layer
|
||||
manifestLayers := make([]manifest.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
manifestLayers = append(manifestLayers, manifest.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
manifestLayers = append(manifestLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
return manifest.WriteManifest(name, configLayer, manifestLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
var layers []manifest.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
|
||||
6
x/imagegen/cache/step.go
vendored
6
x/imagegen/cache/step.go
vendored
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// Supports both single-stream and dual-stream architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
// Used for dual-stream architectures.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
|
||||
@@ -21,8 +21,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -61,14 +59,11 @@ func main() {
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
@@ -166,60 +161,6 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
@@ -32,31 +33,15 @@ type ModelManifest struct {
|
||||
BlobDir string
|
||||
}
|
||||
|
||||
// DefaultBlobDir returns the default blob storage directory.
|
||||
func DefaultBlobDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "linux":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
case "windows":
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
default:
|
||||
return filepath.Join(home, ".ollama", "models", "blobs")
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// DefaultManifestDir returns the default manifest storage directory.
|
||||
// DefaultManifestDir returns the manifest storage directory.
|
||||
// Respects OLLAMA_MODELS.
|
||||
|
||||
func DefaultManifestDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models", "manifests")
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// LoadManifest loads a manifest for the given model name.
|
||||
|
||||
26
x/imagegen/manifest_test.go
Normal file
26
x/imagegen/manifest_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||
modelsDir := filepath.Join(t.TempDir(), "models")
|
||||
|
||||
// Simulate packaged/systemd environment
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
t.Setenv("HOME", "/usr/share/ollama")
|
||||
|
||||
// Manifest dir must respect OLLAMA_MODELS
|
||||
wantManifest := filepath.Join(modelsDir, "manifests")
|
||||
if got := DefaultManifestDir(); got != wantManifest {
|
||||
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
|
||||
}
|
||||
|
||||
// Blob dir must respect OLLAMA_MODELS
|
||||
wantBlobs := filepath.Join(modelsDir, "blobs")
|
||||
if got := DefaultBlobDir(); got != wantBlobs {
|
||||
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,367 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig collects every option accepted by the image generation
// entry points. Zero or negative numeric values are replaced by the defaults
// noted on each field when generation starts.
type GenerateConfig struct {
	// Prompt is the text the image is generated from.
	Prompt string
	// NegativePrompt enables classifier-free guidance (CFG) when non-empty.
	NegativePrompt string
	// CFGScale is the guidance scale; only used when NegativePrompt is set
	// (default: 4.0).
	CFGScale float32
	// Width and Height are the image dimensions (default: 1024 each).
	Width, Height int32
	// Steps is the number of denoising steps (default: 50).
	Steps int
	// Seed is the random seed.
	Seed int64
	// Progress, when non-nil, is called at the start of every denoising step
	// with the 1-based step number and the total step count.
	Progress func(step, totalSteps int)

	// Layer caching (DeepCache/Learning-to-Cache speedup).

	// LayerCache enables layer caching (default: false).
	LayerCache bool
	// CacheInterval is the number of steps between cache refreshes
	// (default: 3).
	CacheInterval int
	// CacheLayers is the number of shallow layers to cache (default: 25).
	CacheLayers int
}
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
L := batchedOutput.Shape()[1]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds the FlowMatchEulerDiscreteScheduler settings. The
// value noted on each field is the one DefaultSchedulerConfig uses.
type SchedulerConfig struct {
	// NumTrainTimesteps: 1000.
	NumTrainTimesteps int32 `json:"num_train_timesteps"`
	// BaseShift is the shift at BaseImageSeqLen: 0.5.
	BaseShift float32 `json:"base_shift"`
	// MaxShift is the shift at MaxImageSeqLen: 0.9.
	MaxShift float32 `json:"max_shift"`
	// BaseImageSeqLen: 256.
	BaseImageSeqLen int32 `json:"base_image_seq_len"`
	// MaxImageSeqLen: 8192.
	MaxImageSeqLen int32 `json:"max_image_seq_len"`
	// ShiftTerminal is the final value the sigma schedule is stretched to
	// when it is greater than zero: 0.02.
	ShiftTerminal float32 `json:"shift_terminal"`
	// UseDynamicShift enables the sequence-length dependent time shift: true.
	UseDynamicShift bool `json:"use_dynamic_shifting"`
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift (mu) for an image sequence
// length by linear interpolation: the result is baseShift at baseSeqLen and
// maxShift at maxSeqLen, extrapolated linearly outside that range.
// It is intended to match Python's calculate_shift function.
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
	// Line through (baseSeqLen, baseShift) and (maxSeqLen, maxShift).
	slope := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
	intercept := baseShift - slope*float32(baseSeqLen)
	return float32(imageSeqLen)*slope + intercept
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the unpacked latent shape [B, C, T, H, W] for an
// image of the given pixel size. For qwen_image the VAE downscales each
// spatial dimension by 8 (integer division) and the latent has 16 channels
// and a single frame.
func GetLatentShape(batchSize, height, width int32) []int32 {
	const (
		vaeDownscale   = 8
		latentChannels = 16
		frames         = 1
	)
	return []int32{
		batchSize,
		latentChannels,
		frames,
		height / vaeDownscale,
		width / vaeDownscale,
	}
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
// [B, L, C*patchSize^2] for an image of the given pixel size, where
// L = (height/8/patchSize) * (width/8/patchSize) and C = 16 latent channels.
//
// The channel dimension is derived from patchSize (64 for the usual
// patchSize of 2) rather than being fixed, so other patch sizes get a
// consistent shape. patchSize must be positive; zero causes an integer
// division by zero.
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
	const (
		vaeDownscale   = 8
		latentChannels = 16
	)
	tokensH := (height / vaeDownscale) / patchSize
	tokensW := (width / vaeDownscale) / patchSize
	return []int32{batchSize, tokensH * tokensW, latentChannels * patchSize * patchSize}
}
|
||||
@@ -1,135 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
// abs32 returns the absolute value of x by clearing its IEEE-754 sign bit,
// so -0 becomes +0 just as with math.Abs.
func abs32(x float32) float32 {
	const signBit = uint32(1) << 31
	return math.Float32frombits(math.Float32bits(x) &^ signBit)
}
|
||||
@@ -1,174 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig is the JSON layout of the config.json that ships with
// the tiny text encoder used in tests. loadTinyTextEncoder copies every field
// into a Qwen25VLConfig.
type TinyTextEncoderConfig struct {
	// Model dimensions.
	HiddenSize       int32 `json:"hidden_size"`
	NumHiddenLayers  int32 `json:"num_hidden_layers"`
	IntermediateSize int32 `json:"intermediate_size"`

	// Attention layout.
	NumAttentionHeads int32 `json:"num_attention_heads"`
	NumKeyValueHeads  int32 `json:"num_key_value_heads"`

	VocabSize  int32   `json:"vocab_size"`
	RMSNormEps float32 `json:"rms_norm_eps"`

	// Rotary embedding parameters.
	RopeTheta    float32 `json:"rope_theta"`
	HeadDim      int32   `json:"head_dim"`
	MRoPESection []int32 `json:"mrope_section"`
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,868 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig describes the Qwen-Image transformer. The values quoted
// on each field are the ones returned by defaultTransformerConfig.
type TransformerConfig struct {
	// Width and depth.
	HiddenDim int32 `json:"hidden_dim"`          // 3072 = 24 heads * 128
	NHeads    int32 `json:"num_attention_heads"` // 24
	HeadDim   int32 `json:"attention_head_dim"`  // 128
	NLayers   int32 `json:"num_layers"`          // 60

	// Latent input/output layout.
	InChannels  int32 `json:"in_channels"`  // 64
	OutChannels int32 `json:"out_channels"` // 16
	PatchSize   int32 `json:"patch_size"`   // 2

	// Conditioning and normalisation.
	JointAttentionDim int32   `json:"joint_attention_dim"` // 3584, the text encoder width
	NormEps           float32 `json:"norm_eps"`            // 1e-6
	AxesDimsRope      []int32 `json:"axes_dims_rope"`      // [16, 56, 56]
	GuidanceEmbeds    bool    `json:"guidance_embeds"`     // false
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs returns the dim/2 RoPE base frequencies for one axis:
// element i is 1 / theta^(i / (dim/2)).
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
	count := dim / 2
	out := make([]float64, count)

	denom := float64(count)
	for i := range out {
		exponent := float64(i) / denom
		out[i] = 1.0 / math.Pow(theta, exponent)
	}
	return out
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
//
// The table has maxIdx rows. Row idx describes position idx, or position
// idx-maxIdx when negative is true. Each row interleaves the cosine and sine
// of position*f for every f in baseFreqs: [cos0, sin0, cos1, sin1, ...].
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
	// Offset added to the row index to obtain the position.
	var offset int32
	if negative {
		offset = -maxIdx
	}

	rows := make([][]float32, maxIdx)
	for idx := range rows {
		position := float64(offset + int32(idx))

		values := make([]float32, 2*len(baseFreqs))
		for k, freq := range baseFreqs {
			theta := position * freq
			values[2*k] = float32(math.Cos(theta))
			values[2*k+1] = float32(math.Sin(theta))
		}
		rows[idx] = values
	}
	return rows
}
|
||||
|
||||
// max returns the larger of a and b. Within this package it shadows the
// built-in max.
func max(a, b int32) int32 {
	if b >= a {
		return b
	}
	return a
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestTransformerConfig tests configuration invariants.
|
||||
func TestTransformerConfig(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Property: hidden_dim = n_heads * head_dim
|
||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: axes_dims_rope sums to head_dim
|
||||
var ropeSum int32
|
||||
for _, d := range cfg.AxesDimsRope {
|
||||
ropeSum += d
|
||||
}
|
||||
if ropeSum != cfg.HeadDim {
|
||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: in_channels = out_channels * patch_size^2
|
||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
||||
if cfg.InChannels != expectedIn {
|
||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
||||
func TestTransformerRoPE(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Test with small image dimensions
|
||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
||||
txtLen := int32(5)
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Verify shapes: [seq_len, head_dim]
|
||||
imgSeqLen := imgH * imgW
|
||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
||||
}
|
||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
||||
}
|
||||
|
||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
||||
}
|
||||
|
||||
// Verify values are finite
|
||||
imgData := ropeCache.ImgFreqs.Data()
|
||||
for i := 0; i < min(100, len(imgData)); i++ {
|
||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestTransformerForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
transformer := &Transformer{}
|
||||
if err := transformer.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load transformer: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(transformer)...)
|
||||
cfg := transformer.Config
|
||||
|
||||
// Small test inputs
|
||||
batchSize := int32(1)
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
imgSeqLen := imgH * imgW
|
||||
txtSeqLen := int32(5)
|
||||
|
||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
||||
|
||||
// Forward pass
|
||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
||||
gotShape := out.Shape()
|
||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,854 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration. The trailing comments give
// the values returned by defaultVAEConfig.
type VAEConfig struct {
	ZDim         int32     `json:"z_dim"`          // 16
	BaseDim      int32     `json:"base_dim"`       // 96
	DimMult      []int32   `json:"dim_mult"`       // [1, 2, 4, 4]
	NumResBlocks int32     `json:"num_res_blocks"` // 2
	LatentsMean  []float32 `json:"latents_mean"`   // 16 values
	LatentsStd   []float32 `json:"latents_std"`    // 16 values
	// The spelling "temperal" is part of the JSON key and must be kept.
	TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
// Input is channels-last: [B, T, H, W, C]
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
// Works with channels-last [B, T, H, W, C] format
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
||||
var h *mlx.Array
|
||||
|
||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection (shortcut handles its own pools if present)
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
||||
|
||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// out: [B*T, 1, H*W, C]
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder.
//
// Forward runs every entry of ResBlocks in order and then, when Upsampler is
// non-nil, the upsampler.
type UpBlock struct {
	ResBlocks []*ResBlock // residual blocks applied sequentially
	Upsampler *Upsample   // optional; nil when the block does not upsample
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block with staged memory management
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// ResBlocks handle their own pools
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Upsampler handles its own pools
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling.
//
// Forward doubles H and W with nearest-neighbor repetition and then applies
// a 3x3 convolution using Conv and Bias.
type Upsample struct {
	Conv *mlx.Array // conv weight; Forward transposes it with (0, 2, 3, 1) before use
	Bias *mlx.Array // optional conv bias; Forward skips it when nil
	Mode string     // mode label given at construction; Forward does not read it
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block of decoder.
//
// Forward applies ResBlock1, Attention and ResBlock2 in that order.
type MidBlock struct {
	ResBlock1 *ResBlock       // first residual block
	Attention *AttentionBlock // attention applied between the two residual blocks
	ResBlock2 *ResBlock       // second residual block
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Each component handles its own pools; we just free inputs
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder.
//
// The fields are listed in the order Decode applies them. All of them are
// populated by Load.
type VAEDecoder struct {
	Config *VAEConfig // set by Load from defaultVAEConfig()

	PostQuantConv *CausalConv3d // loaded from "post_quant_conv"
	ConvIn        *CausalConv3d // loaded from "decoder.conv_in"
	MidBlock      *MidBlock     // loaded from "decoder.mid_block"
	UpBlocks      []*UpBlock    // loaded from "decoder.up_blocks.<i>"
	NormOut       *RMSNorm3D    // loaded from "decoder.norm_out"
	ConvOut       *CausalConv3d // loaded from "decoder.conv_out"
}
|
||||
|
||||
// Load loads the VAE decoder from a directory.
//
// path is handed to safetensors.LoadModelWeights. All weights are loaded as
// bf16 and the decoder layers are then built from fixed tensor prefixes
// (post_quant_conv, decoder.conv_in, decoder.mid_block, decoder.up_blocks.N,
// decoder.norm_out, decoder.conv_out). Progress is printed to stdout.
//
// The configuration always comes from defaultVAEConfig(); nothing is read
// from path to configure the model. On success the weight container is
// released with ReleaseAll. On an error return it is not released here.
func (m *VAEDecoder) Load(path string) error {
	fmt.Println("Loading Qwen-Image VAE decoder...")

	cfg := defaultVAEConfig()
	m.Config = cfg

	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

	// Bulk load all weights as bf16
	fmt.Print(" Loading weights as bf16... ")
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("failed to load weights: %w", err)
	}
	fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))

	fmt.Print(" Loading post_quant_conv... ")
	postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
	if err != nil {
		return err
	}
	m.PostQuantConv = postQuantConv
	fmt.Println("✓")

	fmt.Print(" Loading conv_in... ")
	convIn, err := newCausalConv3d(weights, "decoder.conv_in")
	if err != nil {
		return err
	}
	m.ConvIn = convIn
	fmt.Println("✓")

	// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
	fmt.Print(" Loading mid_block... ")
	midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
	midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
	if err != nil {
		return err
	}
	m.MidBlock = midBlock
	fmt.Println("✓")

	// Up blocks (reversed dim_mult)
	fmt.Print(" Loading up_blocks... ")
	numUpBlocks := len(cfg.DimMult)
	m.UpBlocks = make([]*UpBlock, numUpBlocks)

	// dimsMult is DimMult reversed with its last element duplicated at the
	// front, so block i maps dimsMult[i] -> dimsMult[i+1].
	dimsMult := make([]int32, numUpBlocks+1)
	dimsMult[0] = cfg.DimMult[numUpBlocks-1]
	for i := 0; i < numUpBlocks; i++ {
		dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
	}

	// The decoder walks the encoder's temporal downsample flags in reverse.
	temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
	for i := range cfg.TemperalDownsample {
		temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
	}

	for i := 0; i < numUpBlocks; i++ {
		inDim := cfg.BaseDim * dimsMult[i]
		outDim := cfg.BaseDim * dimsMult[i+1]

		// NOTE(review): presumably the previous block's upsampler conv
		// halves the channel count; confirm against the checkpoint layout.
		if i > 0 {
			inDim = inDim / 2
		}

		// Every block except the last gets an upsampler. The mode string
		// only records whether the matching encoder stage was temporal.
		upsampleMode := ""
		if i < numUpBlocks-1 {
			if temporalUpsample[i] {
				upsampleMode = "upsample3d"
			} else {
				upsampleMode = "upsample2d"
			}
		}

		prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
		upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
		if err != nil {
			return err
		}
		m.UpBlocks[i] = upBlock
	}
	fmt.Printf("✓ [%d blocks]\n", numUpBlocks)

	fmt.Print(" Loading output layers... ")
	normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
	if err != nil {
		return err
	}
	m.NormOut = normOut
	convOut, err := newCausalConv3d(weights, "decoder.conv_out")
	if err != nil {
		return err
	}
	m.ConvOut = convOut
	fmt.Println("✓")

	weights.ReleaseAll()
	return nil
}
|
||||
|
||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
||||
m := &VAEDecoder{}
|
||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image.
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
//
// The latents are denormalized, moved to channels-last layout, run through
// PostQuantConv, ConvIn, MidBlock, every UpBlock, NormOut + SiLU and ConvOut,
// then clipped to [-1, 1] and transposed back to channels-first. Each stage
// frees its input once the output exists; the denormalized copy of z is
// freed too, while the array originally passed in is left to the caller.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
	var x *mlx.Array

	// Stage 1a: Denormalize and transpose
	{
		z = vae.Denormalize(z)
		// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
		z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
		mlx.Eval(z)
	}

	// Stage 1b: PostQuantConv (handles its own pools)
	x = vae.PostQuantConv.Forward(z)
	z.Free()

	// Stage 1c: ConvIn (handles its own pools)
	{
		prev := x
		x = vae.ConvIn.Forward(x)
		prev.Free()
	}

	// Stage 2: Mid block (handles its own pools)
	x = vae.MidBlock.Forward(x)

	// Stage 3: Up blocks (each handles its own pools)
	for _, upBlock := range vae.UpBlocks {
		x = upBlock.Forward(x)
	}

	// Stage 4a: NormOut + silu
	{
		prev := x
		x = vae.NormOut.Forward(x)
		x = silu3D(x)
		prev.Free()
		mlx.Eval(x)
	}

	// Stage 4b: ConvOut (handles its own pools)
	{
		prev := x
		x = vae.ConvOut.Forward(x)
		prev.Free()
	}

	// Stage 4c: Post-processing
	{
		prev := x
		// Clamp to [-1, 1]
		x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
		// Convert back from channels-last to channels-first
		x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
		prev.Free()
		mlx.Eval(x)
	}

	return x
}
|
||||
|
||||
// Denormalize reverses the normalization applied during encoding
|
||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
mean = mlx.ToBFloat16(mean)
|
||||
std = mlx.ToBFloat16(std)
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
||||
}
|
||||
|
||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
||||
|
||||
wShape := weight.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = weight
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
out := mlx.Linear(x, w)
|
||||
outC := w.Dim(1)
|
||||
out = mlx.Reshape(out, B, H, W, outC)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
||||
x = pad2D(x, 1, 1, 1, 1)
|
||||
return conv2D(x, weight, 1, 1)
|
||||
}
|
||||
|
||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
||||
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := w.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
||||
wFlat := mlx.Reshape(w, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 2)
|
||||
x = mlx.Take(x, colIdx, 3)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// Create repeat indices for rows
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
// Create repeat indices for columns
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
// Take along H (axis 1) then W (axis 2)
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
||||
// stride=1, padding=0 (we already padded manually)
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestVAEConfig tests configuration invariants.
|
||||
func TestVAEConfig(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Property: latents_mean and latents_std have z_dim elements
|
||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
||||
}
|
||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
||||
}
|
||||
|
||||
// Property: dim_mult defines 4 stages
|
||||
if len(cfg.DimMult) != 4 {
|
||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
||||
}
|
||||
|
||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
||||
if len(cfg.TemperalDownsample) != 3 {
|
||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
||||
func TestVAELatentsNormalization(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Verify latents_std values are all positive
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std <= 0 {
|
||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are in reasonable range (from actual model)
|
||||
for i, mean := range cfg.LatentsMean {
|
||||
if math.Abs(float64(mean)) > 5 {
|
||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
||||
}
|
||||
}
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std > 10 {
|
||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestVAEDecoderForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
vae := &VAEDecoder{}
|
||||
if err := vae.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(vae)...)
|
||||
|
||||
// Small test input: [B, C, T, H, W]
|
||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
||||
batchSize := int32(1)
|
||||
channels := int32(16)
|
||||
frames := int32(1)
|
||||
latentH := int32(4)
|
||||
latentW := int32(4)
|
||||
|
||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
||||
|
||||
// Decode
|
||||
out := vae.Decode(latents)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
||||
outShape := out.Shape()
|
||||
if outShape[0] != batchSize {
|
||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
||||
}
|
||||
if outShape[1] != 3 {
|
||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
||||
}
|
||||
if outShape[2] != frames {
|
||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
||||
}
|
||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
||||
expectedW := latentW * 16
|
||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
||||
outShape[3], outShape[4], expectedH, expectedW)
|
||||
}
|
||||
|
||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,682 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality).
//
// Forward pads the time axis only on the "before" side. When Weight is 4D
// the layer falls back to a per-frame 2D convolution.
type CausalConv3d struct {
	Weight       *mlx.Array // conv weight: 5D [O, I, kT, kH, kW] or 4D [O, I, kH, kW]
	Bias         *mlx.Array // optional bias; nil when the checkpoint has none
	BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
	KernelT      int32      // Weight.Shape()[2] as read at construction
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape()
|
||||
|
||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
||||
if len(shape) == 4 {
|
||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
||||
return c.forward2D(x)
|
||||
}
|
||||
|
||||
// 3D conv: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
||||
kernelH := wShape[2]
|
||||
kernelW := wShape[3]
|
||||
outC := wShape[0]
|
||||
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad spatially
|
||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
||||
|
||||
// Apply 2D conv
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = mlx.Conv2d(x, weight, 1, 0)
|
||||
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Get output spatial dims
|
||||
outH := H
|
||||
outW := W
|
||||
|
||||
// Reshape back to [B, T, H, W, C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels.
//
// Forward normalizes without a weight and then multiplies by Gamma.
type RMSNorm3D struct {
	Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs.
//
// Forward computes conv2(silu(norm2(conv1(silu(norm1(x)))))) and adds either
// x or Shortcut(x).
type ResBlock struct {
	Norm1    *RMSNorm3D    // norm before the first conv
	Conv1    *CausalConv3d // first conv
	Norm2    *RMSNorm3D    // norm before the second conv
	Conv2    *CausalConv3d // second conv
	Shortcut *CausalConv3d // residual-path conv; nil when inDim == outDim
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block.
//
// It bundles the tensors used by Forward: a pre-attention norm, a fused
// Q/K/V projection and an output projection. Attention is computed across
// the H*W positions of each frame independently.
type AttentionBlock struct {
	Norm      *RMSNorm3D // normalization applied before the QKV projection
	ToQKV     *mlx.Array // fused Q/K/V projection weight; Forward accepts a 2D or 4D tensor
	ToQKVBias *mlx.Array // optional bias for ToQKV; Forward skips it when nil
	Proj      *mlx.Array // output projection weight; Forward accepts a 2D or 4D tensor
	ProjBias  *mlx.Array // optional bias for Proj; Forward skips it when nil
	Dim       int32      // channel count the block was constructed with
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
qkv := mlx.Linear(x, w)
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0)
|
||||
out = mlx.Linear(out, p)
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder.
//
// Forward runs every entry of ResBlocks in order and then, when Upsampler is
// non-nil, the upsampler.
type UpBlock struct {
	ResBlocks []*ResBlock // residual blocks applied sequentially
	Upsampler *Upsample   // optional; nil when the block does not upsample
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
|
||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
||||
// strideT, strideH, strideW are the strides for each dimension
|
||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
// I := wShape[1]
|
||||
kT := wShape[2]
|
||||
kH := wShape[3]
|
||||
kW := wShape[4]
|
||||
|
||||
// For temporal: if T < kT, we need to repeat frames temporally
|
||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
||||
if T < kT {
|
||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
||||
T = kT
|
||||
}
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
||||
|
||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// extractPatches3DStrided extracts 3D patches with given strides
|
||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
numPatches := outT * outH * outW
|
||||
patches := make([]*mlx.Array, numPatches)
|
||||
idx := 0
|
||||
for t := int32(0); t < outT; t++ {
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startT := t * strideT
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
// Extract patch: [B, kT, kH, kW, C]
|
||||
patch := mlx.Slice(x,
|
||||
[]int32{0, startT, startH, startW, 0},
|
||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
||||
// Flatten to [B, C*T*H*W]
|
||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
||||
}
|
||||
|
||||
// extractPatches2DStrided extracts patches with given stride
|
||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * stride
|
||||
startW := j * stride
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
@@ -1,475 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// loadImageFile opens the file at path and decodes it with image.Decode, so
// the format must have a decoder registered by the importing package.
// Failures are wrapped as "open image: ..." or "decode image: ...".
func loadImageFile(path string) (image.Image, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil, fmt.Errorf("open image: %w", openErr)
	}
	defer file.Close()

	decoded, _, decodeErr := image.Decode(file)
	if decodeErr != nil {
		return nil, fmt.Errorf("decode image: %w", decodeErr)
	}
	return decoded, nil
}
|
||||
|
||||
// imageToFloat32Pixels converts the top-left width x height region of img to
// a flat float32 slice laid out as [H, W, C] with C = 3 (R, G, B), each value
// in [0, 1].
//
// Pixels are read relative to img.Bounds().Min, so images whose origin is not
// (0, 0) (for example sub-images) are sampled correctly. The channel values
// come from color.Color.RGBA, i.e. they are 16-bit and alpha-premultiplied,
// and are divided by 65535. Alpha itself is dropped.
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
	origin := img.Bounds().Min
	pixels := make([]float32, width*height*3)
	idx := 0
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			r, g, b, _ := img.At(origin.X+x, origin.Y+y).RGBA()
			pixels[idx] = float32(r) / 65535.0
			pixels[idx+1] = float32(g) / 65535.0
			pixels[idx+2] = float32(b) / 65535.0
			idx += 3
		}
	}
	return pixels
}
|
||||
|
||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
||||
}
|
||||
|
||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
// Add batch dimension [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0)
|
||||
// Convert to bf16
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// clampFloat divides v by weightSum, clamps the quotient to [0, 255], rounds
// it to the nearest integer and returns it as a uint8. weightSum must be
// non-zero for the result to be meaningful.
func clampFloat(v, weightSum float64) uint8 {
	scaled := v / weightSum
	switch {
	case scaled < 0:
		scaled = 0
	case scaled > 255:
		scaled = 255
	}
	return uint8(math.Round(scaled))
}
|
||||
|
||||
// ImageDims records every size derived for one preprocessed input image.
type ImageDims struct {
	// OrigW and OrigH are the dimensions of the image as loaded.
	OrigW int32
	OrigH int32
	// CondW and CondH are the condition-image dimensions fed to the vision
	// encoder.
	CondW int32
	CondH int32
	// VaeW and VaeH are the dimensions of the image fed to the VAE.
	VaeW int32
	VaeH int32
	// LatentW and LatentH are the VAE dimensions divided by the VAE scale
	// factor.
	LatentW int32
	LatentH int32
	// PatchW and PatchH are the latent dimensions divided by the patch size.
	PatchW int32
	PatchH int32
}
|
||||
|
||||
// ProcessorConfig holds the settings used by the image processor.
type ProcessorConfig struct {
	// ConditionImageSize is the target pixel area of the vision encoder input.
	// The pipeline resizes the image to this area before prompt encoding.
	// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456.
	ConditionImageSize int32

	// VAEImageSize is the target pixel area of the VAE input.
	// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576.
	VAEImageSize int32

	// ImageMean and ImageStd are the per-channel normalization statistics
	// applied to the vision encoder input.
	ImageMean []float32
	ImageStd  []float32
}
|
||||
|
||||
// defaultProcessorConfig returns default processor config
|
||||
func defaultProcessorConfig() *ProcessorConfig {
|
||||
return &ProcessorConfig{
|
||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
||||
type Processor struct {
|
||||
Config *ProcessorConfig
|
||||
}
|
||||
|
||||
// Load loads the processor config
|
||||
func (p *Processor) Load(path string) error {
|
||||
p.Config = defaultProcessorConfig()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := bounds.Dx()
|
||||
origH := bounds.Dy()
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize
|
||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
|
||||
return condImage, vaeImage, nil
|
||||
}
|
||||
|
||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
||||
// Used by edit_plus pipeline for multi-image input
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImage converts image to tensor for vision encoder
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
||||
resized := resizeImageBicubic(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
if normalize {
|
||||
arr = p.normalizeImageNet(arr)
|
||||
}
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
|
||||
// Scale to [-1, 1]: arr * 2 - 1
|
||||
arr = mlx.MulScalar(arr, 2.0)
|
||||
arr = mlx.AddScalar(arr, -1.0)
|
||||
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
|
||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
||||
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// smartResize implements the Python Qwen2VL processor's smart_resize logic.
//
// It returns (resizedHeight, resizedWidth), both multiples of factor, chosen
// so that the pixel count lies within [minPixels, maxPixels] while roughly
// keeping the aspect ratio:
//
//   - height and width are first rounded to the nearest multiple of factor
//     (ties to even, like Python's round) and clamped to at least factor;
//   - if that area exceeds maxPixels, both sides are scaled down by
//     sqrt(height*width/maxPixels) and floored to a multiple of factor;
//   - if it is below minPixels, both sides are scaled up by
//     sqrt(minPixels/(height*width)) and ceiled to a multiple of factor.
//
// The scale factor is derived from the original height*width, not from the
// rounded area. Using the rounded area could leave the result below
// minPixels (for example 10x20 with factor 28 stayed at 28x56).
// Non-positive height or width skips the rescaling step and returns the
// clamped rounded values.
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
	h, w, f := float64(height), float64(width), float64(factor)

	atLeastFactor := func(v int32) int32 {
		if v < factor {
			return factor
		}
		return v
	}

	hBar := atLeastFactor(int32(math.RoundToEven(h/f)) * factor)
	wBar := atLeastFactor(int32(math.RoundToEven(w/f)) * factor)

	area := h * w
	if area <= 0 {
		// No meaningful scale can be derived from an empty input.
		return hBar, wBar
	}

	// int64 keeps the product exact for large dimensions.
	total := int64(hBar) * int64(wBar)
	switch {
	case total > int64(maxPixels):
		beta := math.Sqrt(area / float64(maxPixels))
		hBar = atLeastFactor(int32(math.Floor(h/beta/f)) * factor)
		wBar = atLeastFactor(int32(math.Floor(w/beta/f)) * factor)
	case total < int64(minPixels):
		beta := math.Sqrt(float64(minPixels) / area)
		hBar = int32(math.Ceil(h*beta/f)) * factor
		wBar = int32(math.Ceil(w*beta/f)) * factor
	}

	return hBar, wBar
}
|
||||
|
||||
// calculateDimensions returns a (width, height) pair whose area approximates
// targetArea at the given width/height ratio.
//
// Both sides are rounded to the nearest multiple of `multiple` (e.g. 28 for a
// vision encoder with patch 14 and 2x2 merge) and are never smaller than
// `multiple`.
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
	step := float64(multiple)

	// snap rounds v to the nearest multiple of step, with step as the floor.
	snap := func(v float64) float64 {
		v = math.Round(v/step) * step
		if v < step {
			return step
		}
		return v
	}

	rawWidth := math.Sqrt(float64(targetArea) * ratio)
	rawHeight := rawWidth / ratio

	return int32(snap(rawWidth)), int32(snap(rawHeight))
}
|
||||
|
||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
||||
lanczos3 := &draw.Kernel{
|
||||
Support: 3.0,
|
||||
At: func(t float64) float64 {
|
||||
if t == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if t < 0 {
|
||||
t = -t
|
||||
}
|
||||
if t >= 3.0 {
|
||||
return 0.0
|
||||
}
|
||||
// sinc(t) * sinc(t/3)
|
||||
piT := math.Pi * t
|
||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
||||
},
|
||||
}
|
||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
// Convert to RGBA if needed
|
||||
var src *image.RGBA
|
||||
if rgba, ok := img.(*image.RGBA); ok {
|
||||
src = rgba
|
||||
} else {
|
||||
src = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
src.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
||||
cubic := func(x float64) float64 {
|
||||
if x < 0 {
|
||||
x = -x
|
||||
}
|
||||
if x < 1 {
|
||||
return 1.5*x*x*x - 2.5*x*x + 1
|
||||
}
|
||||
if x < 2 {
|
||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Horizontal pass: srcW -> width, keep srcH rows
|
||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
||||
for y := 0; y < srcH; y++ {
|
||||
for dstX := 0; dstX < width; dstX++ {
|
||||
// PIL coordinate mapping: center-to-center
|
||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
||||
baseX := int(math.Floor(srcXf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for i := -1; i <= 2; i++ {
|
||||
sx := baseX + i
|
||||
if sx < 0 {
|
||||
sx = 0
|
||||
}
|
||||
if sx >= srcW {
|
||||
sx = srcW - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
||||
c := src.RGBAAt(sx, y)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
temp.SetRGBA(dstX, y, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Vertical pass: srcH -> height
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for x := 0; x < width; x++ {
|
||||
for dstY := 0; dstY < height; dstY++ {
|
||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
||||
baseY := int(math.Floor(srcYf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for j := -1; j <= 2; j++ {
|
||||
sy := baseY + j
|
||||
if sy < 0 {
|
||||
sy = 0
|
||||
}
|
||||
if sy >= srcH {
|
||||
sy = srcH - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
||||
c := temp.RGBAAt(x, sy)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
dst.SetRGBA(x, dstY, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
||||
const vaeScaleFactor int32 = 8
|
||||
const patchSize int32 = 2
|
||||
|
||||
condImages := make([]*mlx.Array, len(imagePaths))
|
||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
||||
dims := make([]ImageDims, len(imagePaths))
|
||||
|
||||
for i, imagePath := range imagePaths {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := int32(bounds.Dx())
|
||||
origH := int32(bounds.Dy())
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Calculate derived dimensions
|
||||
latentW := vaeW / vaeScaleFactor
|
||||
latentH := vaeH / vaeScaleFactor
|
||||
patchW := latentW / patchSize
|
||||
patchH := latentH / patchSize
|
||||
|
||||
dims[i] = ImageDims{
|
||||
OrigW: origW,
|
||||
OrigH: origH,
|
||||
CondW: condW,
|
||||
CondH: condH,
|
||||
VaeW: vaeW,
|
||||
VaeH: vaeH,
|
||||
LatentW: latentW,
|
||||
LatentH: latentH,
|
||||
PatchW: patchW,
|
||||
PatchH: patchH,
|
||||
}
|
||||
|
||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
}
|
||||
|
||||
return condImages, vaeImages, dims, nil
|
||||
}
|
||||
@@ -1,625 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
||||
// It reuses components from qwen_image where possible.
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig collects every option for an image edit.
type GenerateConfig struct {
	// Prompt is the text description of the desired edit.
	Prompt string
	// NegativePrompt is the unconditional prompt for CFG (empty string "" is
	// valid).
	NegativePrompt string
	// CFGScale enables CFG when > 1.0 (default: 4.0).
	CFGScale float32
	// Width is the output width (default: from input image).
	Width int32
	// Height is the output height (default: from input image).
	Height int32
	// Steps is the number of denoising steps (default: 50).
	Steps int
	// Seed is the random seed.
	Seed int64
	// Progress is an optional progress callback.
	Progress func(step, totalSteps int)
}
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Processor *Processor // Image processor for vision encoder
|
||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
||||
VAE *VAE // Combined encoder + decoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image-Edit model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer from processor directory
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
processorPath := filepath.Join(modelPath, "processor")
|
||||
tok, err := tokenizer.Load(processorPath)
|
||||
if err != nil {
|
||||
// Fallback to tokenizer directory
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err = tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load processor (image preprocessing config)
|
||||
fmt.Print(" Loading processor... ")
|
||||
m.Processor = &Processor{}
|
||||
if err := m.Processor.Load(processorPath); err != nil {
|
||||
return fmt.Errorf("processor: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer (reuse qwen_image)
|
||||
m.Transformer = &qwen_image.Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE (encoder + decoder)
|
||||
m.VAE = &VAE{}
|
||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edit edits an image based on a text prompt.
|
||||
// inputImagePath: path to input image
|
||||
// prompt: text description of desired edit
|
||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// EditFromConfig edits images using the unified config struct.
|
||||
// Accepts one or more input images.
|
||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if len(inputImagePaths) == 0 {
|
||||
return nil, fmt.Errorf("no input images provided")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := m.edit(inputImagePaths, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EditImage implements model.ImageEditModel interface.
|
||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// EditMultiImage edits using multiple source images.
|
||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
return m.EditFromConfig(inputImagePaths, cfg)
|
||||
}
|
||||
|
||||
// edit is the internal editing pipeline that handles one or more images.
|
||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
|
||||
// Load and preprocess all input images
|
||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
||||
}
|
||||
for _, img := range condImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
for _, img := range vaeImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
mlx.Eval(append(condImages, vaeImages...)...)
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
vaeScaleFactor := int32(8)
|
||||
|
||||
// Output dimensions - if not specified, use first input image dimensions
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = inputDims[0].VaeW
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = inputDims[0].VaeH
|
||||
}
|
||||
|
||||
// Output (noise) latent dimensions
|
||||
outLatentH := cfg.Height / vaeScaleFactor
|
||||
outLatentW := cfg.Width / vaeScaleFactor
|
||||
outPH := outLatentH / tcfg.PatchSize
|
||||
outPW := outLatentW / tcfg.PatchSize
|
||||
noiseSeqLen := outPH * outPW
|
||||
imgSeqLen := noiseSeqLen
|
||||
|
||||
// Encode prompt with all images for conditioning
|
||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
|
||||
var negEmb *mlx.Array
|
||||
if useCFG {
|
||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(negEmb)
|
||||
mlx.Eval(negEmb)
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
for i, vaeImage := range vaeImages {
|
||||
imageLatents := m.VAE.Encode(vaeImage)
|
||||
imageLatents = m.VAE.Normalize(imageLatents)
|
||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
||||
mlx.Keep(packed)
|
||||
mlx.Eval(packed)
|
||||
allImageLatentsPacked[i] = packed
|
||||
}
|
||||
|
||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
||||
mlx.Keep(imageLatentsPacked)
|
||||
mlx.Eval(imageLatentsPacked)
|
||||
|
||||
// Scheduler
|
||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
||||
|
||||
// Init noise latents in packed format
|
||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// RoPE cache
|
||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
mlx.Eval(timestep)
|
||||
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Tile inputs: [1, L, D] -> [2, L, D]
|
||||
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
||||
}
|
||||
|
||||
// Free denoising temporaries
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
// Decode latents
|
||||
decoded := m.decodeAndPostprocess(latents)
|
||||
latents.Free()
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
||||
// This prevents CFG from inflating magnitude too much.
|
||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
||||
// Upcast to float32 for precision
|
||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
||||
|
||||
// CFG: pred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posF32, negF32)
|
||||
scaledDiff := mlx.MulScalar(diff, scale)
|
||||
combPred := mlx.Add(negF32, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
|
||||
mlx.Eval(output)
|
||||
return mlx.ToBFloat16(output)
|
||||
}
|
||||
|
||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
||||
latents = m.VAE.Denormalize(latents)
|
||||
decoded := m.VAE.Decode(latents)
|
||||
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
mlx.Eval(decoded)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
||||
// Handles single or multiple input images with different resolutions.
|
||||
//
|
||||
// Parameters:
|
||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
||||
// - txtLen: text sequence length
|
||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
||||
//
|
||||
// Returns RoPE cache where:
|
||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
||||
// - Following positions are for each input image (interpolated from output res)
|
||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
|
||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame position
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = posFreqsT[framePos][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame -1: use last row of negative frame frequencies
|
||||
negFrameIdx := maxIdx - 1
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Total image sequence length: noise + all input images
|
||||
noiseSeqLen := outPH * outPW
|
||||
totalImgLen := noiseSeqLen
|
||||
for _, dims := range inputDims {
|
||||
totalImgLen += dims.PatchH * dims.PatchW
|
||||
}
|
||||
|
||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
||||
idx := int32(0)
|
||||
|
||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
||||
for y := int32(0); y < outPH; y++ {
|
||||
for x := int32(0); x < outPW; x++ {
|
||||
row := computePosFreqs(0, y, x)
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
|
||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
||||
// (_compute_condition_freqs), positive indices for others.
|
||||
numImages := len(inputDims)
|
||||
lastImgIdx := numImages - 1
|
||||
for imgIdx, dims := range inputDims {
|
||||
inPH := dims.PatchH
|
||||
inPW := dims.PatchW
|
||||
|
||||
// Determine frame index for this image
|
||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
||||
|
||||
// Map each input position to an output position using linear interpolation
|
||||
for y := int32(0); y < inPH; y++ {
|
||||
for x := int32(0); x < inPW; x++ {
|
||||
// Interpolate: map input (y, x) to output grid position
|
||||
// This is the key fix from DiffSynth's forward_sampling
|
||||
var yOut, xOut int32
|
||||
if inPH == 1 {
|
||||
yOut = 0
|
||||
} else {
|
||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
||||
yOut = y * (outPH - 1) / (inPH - 1)
|
||||
}
|
||||
if inPW == 1 {
|
||||
xOut = 0
|
||||
} else {
|
||||
xOut = x * (outPW - 1) / (inPW - 1)
|
||||
}
|
||||
|
||||
var row []float32
|
||||
if useNegFrame {
|
||||
// Last image in multi-image uses frame -1
|
||||
row = computeNegFrameFreqs(yOut, xOut)
|
||||
} else {
|
||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
||||
frameIdx := int32(imgIdx + 1)
|
||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
||||
}
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies - start after max video index
|
||||
maxVidIdx := max(outPH/2, outPW/2)
|
||||
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &qwen_image.RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
// Expected values from Python:
|
||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
||||
expectedFreqsT := []float64{
|
||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
||||
}
|
||||
|
||||
expectedFreqsH_first4 := []float64{
|
||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
||||
}
|
||||
|
||||
expectedFreqsH_last4 := []float64{
|
||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
||||
}
|
||||
|
||||
// Test temporal frequencies (dim=16)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
if len(freqsT) != 8 {
|
||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
||||
}
|
||||
for i, expected := range expectedFreqsT {
|
||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test height/width frequencies (dim=56)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
||||
if len(freqsH) != 28 {
|
||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
||||
}
|
||||
for i, expected := range expectedFreqsH_first4 {
|
||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
||||
}
|
||||
}
|
||||
for i, expected := range expectedFreqsH_last4 {
|
||||
idx := 24 + i // last 4 of 28
|
||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
||||
func TestMakeFreqTable(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Test positive table
|
||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
|
||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
||||
if posTable[0][i] != 1.0 {
|
||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
||||
}
|
||||
if posTable[0][i+1] != 0.0 {
|
||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
||||
}
|
||||
}
|
||||
|
||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
||||
}
|
||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
||||
}
|
||||
|
||||
// Test negative table
|
||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
||||
|
||||
// negTable[4095] corresponds to position -1
|
||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
||||
}
|
||||
|
||||
// negTable[4094] corresponds to position -2
|
||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
||||
cos2 := math.Cos(2.0)
|
||||
sin2 := math.Sin(2.0)
|
||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
||||
if !mlx.GPUIsAvailable() {
|
||||
t.Skip("GPU not available")
|
||||
}
|
||||
|
||||
mlx.SetDefaultDeviceCPU()
|
||||
|
||||
// 4x4 patch grid, single image
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
txtLen := int32(5)
|
||||
axesDims := []int32{16, 56, 56}
|
||||
|
||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
||||
|
||||
// Check shapes
|
||||
imgShape := cache.ImgFreqs.Shape()
|
||||
if imgShape[0] != 16 { // 4*4 patches
|
||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
||||
}
|
||||
|
||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
||||
mlx.Eval(imgFreqsCPU)
|
||||
imgData := imgFreqsCPU.Data()
|
||||
|
||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
||||
for i := 0; i < 16; i += 2 {
|
||||
cosVal := imgData[i]
|
||||
sinVal := imgData[i+1]
|
||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
||||
}
|
||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
||||
}
|
||||
}
|
||||
|
||||
cache.ImgFreqs.Free()
|
||||
cache.TxtFreqs.Free()
|
||||
}
|
||||
|
||||
// TestScaleRopePositions checks the centered position arithmetic used for
// scale_rope=True on a 4x4 grid.
//
// With extent 4, half = 2 and negCount = 4 - 2 = 2, so indices 0 and 1 are
// negative positions and the expected sequence on each axis is -2, -1, 0, 1.
// The test exercises the arithmetic only; it does not call the RoPE builders.
func TestScaleRopePositions(t *testing.T) {
	pH, pW := int32(4), int32(4)

	// centeredPos mirrors the two-branch rule used when building RoPE rows.
	centeredPos := func(i, extent int32) int32 {
		half := extent / 2
		negCount := extent - half
		if i < negCount {
			return -(extent - half) + i
		}
		return i - negCount
	}

	wantH := []int32{-2, -1, 0, 1}
	wantW := []int32{-2, -1, 0, 1}

	for y := int32(0); y < pH; y++ {
		if got := centeredPos(y, pH); got != wantH[y] {
			t.Errorf("y=%d: expected h_pos=%d, got %d", y, wantH[y], got)
		}
	}

	for x := int32(0); x < pW; x++ {
		if got := centeredPos(x, pW); got != wantW[x] {
			t.Errorf("x=%d: expected w_pos=%d, got %d", x, wantW[x], got)
		}
	}
}
|
||||
|
||||
// TestRoPEHeadDimensions checks the head-dimension breakdown for
// axes_dims_rope = [16, 56, 56].
//
// Each axis contributes half its size in frequencies (8 + 28 + 28 = 64), and
// each frequency yields a cos and a sin value, giving 128 values per position.
// That is meant to equal the transformer's attention head dimension
// (hidden_size 3072 / 24 heads = 128).
func TestRoPEHeadDimensions(t *testing.T) {
	axesDims := []int32{16, 56, 56}

	var freqCount int32
	for _, d := range axesDims {
		freqCount += d / 2
	}
	headDim := freqCount * 2

	if freqCount != 64 {
		t.Errorf("expected 64 frequency values, got %d", freqCount)
	}
	if headDim != 128 {
		t.Errorf("expected head_dim=128, got %d", headDim)
	}
}
|
||||
|
||||
@@ -1,642 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds the Qwen-Image VAE configuration.
// The value noted on each field is the one used by defaultVAEConfig.
type VAEConfig struct {
	// ZDim is the latent channel count (16).
	ZDim int32 `json:"z_dim"`
	// BaseDim is the base channel width (96).
	BaseDim int32 `json:"base_dim"`
	// DimMult lists the per-stage multipliers of BaseDim ([1, 2, 4, 4]).
	DimMult []int32 `json:"dim_mult"`
	// NumResBlocks is the number of residual blocks per stage (2).
	NumResBlocks int32 `json:"num_res_blocks"`
	// LatentsMean holds the per-channel mean used for latent (de)normalization (16 values).
	LatentsMean []float32 `json:"latents_mean"`
	// LatentsStd holds the per-channel std used for latent (de)normalization (16 values).
	LatentsStd []float32 `json:"latents_std"`
	// TemperalDownsample flags which downsample stages are also temporal
	// ([false, true, true]). The spelling follows the JSON key.
	TemperalDownsample []bool `json:"temperal_downsample"`
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// VAE is the full VAE with encoder and decoder
|
||||
type VAE struct {
|
||||
Config *VAEConfig
|
||||
Encoder *VAEEncoder
|
||||
Decoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the VAE from a directory
|
||||
func (m *VAE) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load weights as f32 for quality (matches Python default behavior)
|
||||
// VAE decoder precision is critical for final image quality
|
||||
fmt.Print(" Loading weights as f32... ")
|
||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
// Load encoder
|
||||
fmt.Print(" Loading encoder... ")
|
||||
m.Encoder = &VAEEncoder{}
|
||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load decoder
|
||||
fmt.Print(" Loading decoder... ")
|
||||
m.Decoder = &VAEDecoder{}
|
||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("decoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
||||
return m.Encoder.Encode(x)
|
||||
}
|
||||
|
||||
// Decode decodes latents to image
|
||||
// z: [B, C, T, H, W] latents (denormalized)
|
||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
||||
return m.Decoder.Decode(z)
|
||||
}
|
||||
|
||||
// Normalize applies latent normalization
|
||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
// Mean/std are f32, will match z dtype through broadcasting
|
||||
return mlx.Div(mlx.Sub(z, mean), std)
|
||||
}
|
||||
|
||||
// Denormalize reverses latent normalization
|
||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
// Convert latents to f32 for VAE decoder quality
|
||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// VAEEncoder is the encoder part of the VAE
|
||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
||||
// - Blocks 0,1: ResBlocks (base_dim)
|
||||
// - Block 2: Downsample
|
||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
||||
// - Block 5: Downsample + temporal
|
||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
||||
// - Block 8: Downsample + temporal
|
||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
||||
type VAEEncoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
ConvIn *CausalConv3d
|
||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
||||
MidBlock *MidBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
QuantConv *CausalConv3d
|
||||
}
|
||||
|
||||
// EncoderBlock is either a ResBlock or a Downsample
|
||||
type EncoderBlock interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
IsDownsample() bool
|
||||
}
|
||||
|
||||
// EncoderResBlock wraps ResBlock
|
||||
type EncoderResBlock struct {
|
||||
*ResBlock
|
||||
}
|
||||
|
||||
// IsDownsample always reports false: a residual block is not a downsampler.
func (b *EncoderResBlock) IsDownsample() bool {
	return false
}
|
||||
|
||||
// EncoderDownsample is a downsample layer
|
||||
type EncoderDownsample struct {
|
||||
Resample *CausalConv3d
|
||||
TimeConv *CausalConv3d // Optional temporal downsample
|
||||
}
|
||||
|
||||
// IsDownsample always reports true.
func (d *EncoderDownsample) IsDownsample() bool {
	return true
}
|
||||
|
||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Spatial downsample with stride 2
|
||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
||||
x = d.forwardSpatialDownsample(x)
|
||||
|
||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
||||
// Since we don't support streaming, we skip time_conv entirely.
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := d.Resample.Weight.Shape()
|
||||
outC := wShape[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
||||
|
||||
// Apply 2D conv with stride 2
|
||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
|
||||
if d.Resample.Bias != nil {
|
||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
||||
outH := (H + 1) / 2
|
||||
outW := (W + 1) / 2
|
||||
|
||||
// Reshape back to [B, T, H', W', C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// loadFromWeights loads the encoder from pre-loaded weights
|
||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
e.Config = cfg
|
||||
|
||||
// Conv in
|
||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvIn = convIn
|
||||
|
||||
// Encoder uses flat block structure:
|
||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
||||
|
||||
// Track dimensions
|
||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
||||
blockIdx := 0
|
||||
|
||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
||||
inDim := cfg.BaseDim
|
||||
if stage > 0 {
|
||||
inDim = dims[stage-1]
|
||||
}
|
||||
outDim := dims[stage]
|
||||
|
||||
// ResBlocks for this stage (num_res_blocks per stage)
|
||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
currentInDim := inDim
|
||||
if r > 0 {
|
||||
currentInDim = outDim
|
||||
}
|
||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, block)
|
||||
blockIdx++
|
||||
}
|
||||
|
||||
// Downsample after each stage except the last
|
||||
if stage < len(cfg.DimMult)-1 {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, down)
|
||||
blockIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.MidBlock = midBlock
|
||||
|
||||
// Norm out
|
||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.NormOut = normOut
|
||||
|
||||
// Conv out
|
||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvOut = convOut
|
||||
|
||||
// Quant conv
|
||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.QuantConv = quantConv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncoderResBlock{block}, nil
|
||||
}
|
||||
|
||||
// newEncoderDownsample creates a downsample layer for the encoder
|
||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var timeConv *CausalConv3d
|
||||
if temporal {
|
||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
||||
}
|
||||
|
||||
return &EncoderDownsample{
|
||||
Resample: resample,
|
||||
TimeConv: timeConv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(x)
|
||||
|
||||
// Conv in
|
||||
x = e.ConvIn.Forward(x)
|
||||
|
||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
||||
for _, block := range e.Blocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = e.MidBlock.Forward(x)
|
||||
|
||||
// Norm + silu
|
||||
{
|
||||
prev := x
|
||||
x = e.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Conv out
|
||||
{
|
||||
prev := x
|
||||
x = e.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Quant conv
|
||||
{
|
||||
prev := x
|
||||
x = e.QuantConv.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Get mode from distribution (first half of channels = mean)
|
||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
||||
shape := x.Shape()
|
||||
latentC := shape[4] / 2
|
||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
||||
|
||||
// Convert back to channels-first [N, C, T, H, W]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the decoder part of the VAE
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// loadFromWeights loads the decoder from pre-loaded weights
|
||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
d.Config = cfg
|
||||
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PostQuantConv = postQuantConv
|
||||
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvIn = convIn
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.MidBlock = midBlock
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.UpBlocks[i] = upBlock
|
||||
}
|
||||
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NormOut = normOut
|
||||
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvOut = convOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] denormalized latents
|
||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Convert from channels-first to channels-last
|
||||
{
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// PostQuantConv
|
||||
x = d.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// ConvIn
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = d.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks
|
||||
for _, upBlock := range d.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = d.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// ConvOut
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Post-processing: clamp and convert back to channels-first
|
||||
{
|
||||
prev := x
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// DownBlock handles downsampling in encoder
|
||||
type DownBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Downsampler *Downsample
|
||||
}
|
||||
|
||||
// newDownBlock creates a down block
|
||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var downsampler *Downsample
|
||||
if downsampleMode != "" {
|
||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
||||
}
|
||||
|
||||
return &DownBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Downsampler: downsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies down block
|
||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range d.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if d.Downsampler != nil {
|
||||
prev := x
|
||||
x = d.Downsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Downsample handles spatial downsampling
|
||||
type Downsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newDownsample creates a downsampler
|
||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Downsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := d.Conv.Shape()[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
||||
// For 3x3 stride 2: pad 1 on all sides
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
|
||||
// Conv with stride 2 using manual strided patching
|
||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
if d.Bias != nil {
|
||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// modelConfig represents the HuggingFace config.json structure
|
||||
@@ -35,22 +36,22 @@ type modelConfig struct {
|
||||
|
||||
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
||||
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
||||
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
var config modelConfig
|
||||
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
|
||||
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
// Calculate total tensor bytes from manifest layers
|
||||
var totalBytes int64
|
||||
var tensorCount int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
totalBytes += layer.Size
|
||||
tensorCount++
|
||||
}
|
||||
@@ -151,27 +152,30 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
return getTensorInfoFromManifest(manifest)
|
||||
return getTensorInfoFromManifest(mf)
|
||||
}
|
||||
|
||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||
// This is separated for testability.
|
||||
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
|
||||
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
|
||||
var tensors []api.Tensor
|
||||
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType != manifest.MediaTypeImageTensor {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the safetensors header from the blob
|
||||
blobPath := manifest.BlobPath(layer.Digest)
|
||||
blobPath, err := manifest.BlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
if err != nil {
|
||||
// Skip tensors we can't read
|
||||
@@ -197,15 +201,15 @@ func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor,
|
||||
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
||||
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
||||
// Otherwise returns the torch_dtype from config.json.
|
||||
func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
func GetSafetensorsDtype(name model.Name) (string, error) {
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Check if model is quantized by looking for _scale tensors
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
for _, layer := range mf.Layers {
|
||||
if layer.MediaType == manifest.MediaTypeImageTensor {
|
||||
if strings.HasSuffix(layer.Name, "_scale") {
|
||||
// Model is quantized - return FP8 (affine quantization)
|
||||
return "FP8", nil
|
||||
@@ -217,7 +221,7 @@ func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
return "", fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
)
|
||||
|
||||
func TestBuildModelInfo(t *testing.T) {
|
||||
@@ -451,8 +451,14 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
// Create a temp directory for blobs
|
||||
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||
|
||||
blobDir := filepath.Join(tempDir, "blobs")
|
||||
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create blobs dir: %v", err)
|
||||
}
|
||||
|
||||
// Create test tensor blobs
|
||||
tensors := []struct {
|
||||
@@ -463,26 +469,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "model.embed_tokens.weight",
|
||||
digest: "sha256:abc123",
|
||||
digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
|
||||
dtype: "BF16",
|
||||
shape: []int64{262144, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.layers.0.self_attn.q_proj.weight",
|
||||
digest: "sha256:def456",
|
||||
digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
|
||||
dtype: "BF16",
|
||||
shape: []int64{2560, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.norm.weight",
|
||||
digest: "sha256:ghi789",
|
||||
digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
|
||||
dtype: "F32",
|
||||
shape: []int64{2560},
|
||||
},
|
||||
}
|
||||
|
||||
// Create blob files
|
||||
var layers []imagegen.ManifestLayer
|
||||
var layers []manifest.Layer
|
||||
for _, tensor := range tensors {
|
||||
// Create safetensors blob
|
||||
header := map[string]any{
|
||||
@@ -498,15 +504,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
// Write blob file
|
||||
blobName := "sha256-" + tensor.digest[7:]
|
||||
blobPath := filepath.Join(tempDir, blobName)
|
||||
// Write blob file at the path manifest.BlobsPath resolves for this digest
|
||||
blobPath, err := manifest.BlobsPath(tensor.digest)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get blob path: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob: %v", err)
|
||||
}
|
||||
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
layers = append(layers, manifest.Layer{
|
||||
MediaType: manifest.MediaTypeImageTensor,
|
||||
Digest: tensor.digest,
|
||||
Size: int64(buf.Len() + 1000), // header + fake data
|
||||
Name: tensor.name,
|
||||
@@ -514,21 +522,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add a non-tensor layer (should be skipped)
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
layers = append(layers, manifest.Layer{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: "sha256:config",
|
||||
Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
Size: 100,
|
||||
Name: "config.json",
|
||||
})
|
||||
|
||||
manifest := &imagegen.ModelManifest{
|
||||
Manifest: &imagegen.Manifest{
|
||||
Layers: layers,
|
||||
},
|
||||
BlobDir: tempDir,
|
||||
mf := &manifest.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
result, err := getTensorInfoFromManifest(manifest)
|
||||
result, err := getTensorInfoFromManifest(mf)
|
||||
if err != nil {
|
||||
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user