Compare commits

...

14 Commits

Author SHA1 Message Date
Patrick Devine
5818394002 gofumpt the gofmt 2026-01-20 17:35:23 -08:00
Patrick Devine
51631bcba0 replace modelpath.go w/ types/model/name.go 2026-01-20 17:14:54 -08:00
Patrick Devine
6e43944c94 housekeeping: move manifest implementation 2026-01-20 16:19:21 -08:00
next-n
d6dd430abd x/imagegen: respect OLLAMA_MODELS for manifests and blobs (#13797) 2026-01-20 13:01:52 -08:00
Daniel Hiltgen
ae78112c50 test: add lfm2.5-thinking coverage (#13802) 2026-01-20 12:57:02 -08:00
Jeffrey Morgan
01cf7445f3 model: add lfm2 architecture and LFM2.5-1.2B-Thinking support (#13792)
Co-Authored-By: TommyBoiss <165361500+TommyBoiss@users.noreply.github.com>
2026-01-20 12:20:53 -08:00
Jeffrey Morgan
31085d5e53 fix: use api.GenerateRequest for image generation test (#13793)
Remove non-existent x/imagegen/api import and use the standard
api.GenerateRequest/GenerateResponse with the Image field instead.
2026-01-20 03:23:31 -08:00
Daniel Hiltgen
c42e9d244f test: add image gen test case (#13698)
* test: fix type regression in tools test.

* test: add image gen integration test
2026-01-19 16:01:31 -08:00
Devon Rifkin
e98b5e8b4e /api/show: default to empty model_info (#13785)
For `/api/show`, a fully missing `model_info` field trips up various
integrators (including a recent Android Studio integration).

The primary source of missing info tends to come from models with a
remote that are also missing other data. It seems better to me to return
an empty `model_info` than making up some other fields within
`model_info` (like saying the architecture is `remote` or something like
that). So this does slightly change `/api/show`'s behavior that possibly
someone is relying on, but it seems more important to ensure the field
is always there (from a quick sampling integrations seem to be robust to
missing fields _within_ it).

Fixes: https://github.com/ollama/ollama/issues/13783
2026-01-19 15:26:17 -08:00
Jeffrey Morgan
68e00c7c36 fix: prevent image generation models from loading during deletion (#13781)
Move the unload check (empty prompt + KeepAlive=0) before the image
generation model dispatch in GenerateHandler. This prevents models like
flux from being loaded into memory just to be immediately unloaded when
running `ollama rm`.

Also fix a bug in DeleteHandler where `args[0]` was used instead of
`arg` in the delete loop, causing only the first model to be unloaded
when deleting multiple models.
2026-01-19 12:48:34 -08:00
Jeffrey Morgan
4f138a1749 model: add Glm4MoeLiteForCausalLM architecture to support GLM-4.7-Flash (#13779) 2026-01-19 12:47:17 -08:00
Jeffrey Morgan
03bf241c33 x/imagegen: add FP4 quantization support for image generation models (#13773)
Add --quantize fp4 support to ollama create for image generation models
(flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add GetQuantization() to WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
2026-01-19 00:54:54 -08:00
Jeffrey Morgan
a887406c24 x/imagegen: add preliminary support for FLUX.2-klein model (#13772) 2026-01-18 22:30:49 -08:00
Jeffrey Morgan
d51e95ba7e server: prevent image generation models from reloading on every request (#13771)
The loadImageGen function was not setting Options on the runnerRef,
causing needsReload() to always return true (since it checks if
runner.Options == nil). This resulted in the image generation
subprocess being killed and restarted for every request.
2026-01-18 20:50:04 -08:00
82 changed files with 10047 additions and 1019 deletions

View File

@@ -749,7 +749,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`

View File

@@ -899,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, arg := range args {
// Unload the model if it's running before deletion
if err := loadOrUnloadModel(cmd, &runOptions{
Model: args[0],
Model: arg,
KeepAlive: &api.Duration{Duration: 0},
}); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
}
}

View File

@@ -311,6 +311,10 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -0,0 +1,150 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type glm4MoeLiteModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "glm4moelite"
kv["general.type"] = "model"
kv["glm4moelite.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["glm4moelite.attention.head_count"] = numHeads
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
kv["glm4moelite.attention.value_length"] = p.VHeadDim
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
kv["glm4moelite.embedding_length"] = p.HiddenSize
kv["glm4moelite.expert_count"] = p.ExpertCount
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
kv["glm4moelite.expert_gating_func"] = uint32(2)
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["tokenizer.ggml.pre"] = "glm4"
return kv
}
func (p *glm4MoeLiteModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

100
convert/convert_lfm2.go Normal file
View File

@@ -0,0 +1,100 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights: [D, 1, K] -> [D, K]
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
if len(shape) == 3 && shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}

View File

@@ -40,6 +40,7 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -269,6 +269,8 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}
@@ -856,7 +858,9 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"glm4moelite",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",

View File

@@ -0,0 +1,148 @@
//go:build integration
package integration
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
var imageBase64 string
err := client.Generate(ctx, &api.GenerateRequest{
Model: model,
Prompt: prompt,
}, func(resp api.GenerateResponse) error {
if resp.Image != "" {
imageBase64 = resp.Image
}
return nil
})
if err != nil {
return "", fmt.Errorf("failed to generate image: %w", err)
}
if imageBase64 == "" {
return "", fmt.Errorf("no image data in response")
}
return imageBase64, nil
}

View File

@@ -38,6 +38,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -143,6 +144,7 @@ var (
"granite3.3",
"hermes3",
"internlm2",
"lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -263,6 +265,7 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",

View File

@@ -1,4 +1,4 @@
package server
package manifest
import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
Status string `json:"-"`
}
const (
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
status: fmt.Sprintf("%s %s", status, digest),
Status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
status: fmt.Sprintf("using existing layer %s", digest),
Status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}

View File

@@ -1,10 +1,9 @@
package server
package manifest
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -33,6 +32,32 @@ func (m *Manifest) Size() (size int64) {
return
}
func (m *Manifest) Digest() string {
return m.digest
}
func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}
// ReadConfigJSON reads and unmarshals a config layer as JSON.
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
blobPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
data, err := os.ReadFile(blobPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
}
return fmt.Errorf("config %q not found in manifest", configPath)
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
@@ -74,7 +99,7 @@ func (m *Manifest) RemoveLayers() error {
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err

View File

@@ -1,4 +1,4 @@
package server
package manifest
import (
"encoding/json"

95
manifest/paths.go Normal file
View File

@@ -0,0 +1,95 @@
package manifest
import (
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
var ErrInvalidDigestFormat = errors.New("invalid digest format")
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// GetManifestPathForName returns the path to the manifest file for a specific model name.
func GetManifestPathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
manifests, err := GetManifestPath()
if err != nil {
return "", err
}
return filepath.Join(manifests, n.Filepath()), nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}

View File

@@ -162,6 +162,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1641,6 +1641,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -0,0 +1,304 @@
package glm4moelite
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads int
eps,
ropeBase float32
kqScale float64
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
compressedKV.Stride(1), 1,
compressedKV.Stride(1), compressedKV.Dim(1),
)
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
if moe.ExpProbsBias != nil {
scores = scores.Add(ctx, moe.ExpProbsBias)
}
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "glm4":
pre = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
pre...,
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glm4moelite", New)
}

410
model/models/lfm2/cache.go Normal file
View File

@@ -0,0 +1,410 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
// without modifying permanent state (slotForSeq, refCount)
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int // track newly allocated slots that need zeroing
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero conv state for newly allocated slots to clear stale data from previous sequences
if len(newSlots) > 0 {
c.zeroConvSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 || len(c.convStates) == 0 {
return
}
// Use input context for creating tensors
inputCtx := ctx.Input()
// Create slot indices tensor
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Create zero tensor for the slots (SetRows requires F32 source)
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
// Zero each layer's conv state for these slots
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -0,0 +1,444 @@
package lfm2
import (
"testing"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
return &HybridCache{
hiddenSize: 256,
dConv: 3,
maxSequences: maxSequences,
refCount: make([]int, maxSequences),
freeSlots: initFreeSlots(maxSequences),
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func initFreeSlots(n int) []int {
slots := make([]int, 0, n)
for i := n - 1; i >= 0; i-- {
slots = append(slots, i)
}
return slots
}
func TestHybridCache_SlotAllocation(t *testing.T) {
cache := createSlotOnlyCache(4)
// Verify initial state
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
}
// Allocate all slots
for range 4 {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.refCount[slot] = 1
}
// Should be full now
if len(cache.freeSlots) != 0 {
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
}
// Trying to allocate another should fail
_, err := cache.allocSlot()
if err != kvcache.ErrKvCacheFull {
t.Errorf("expected ErrKvCacheFull, got %v", err)
}
}
func TestHybridCache_SlotReuse(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate a slot
slot1, _ := cache.allocSlot()
cache.refCount[slot1] = 1
// Free it
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
// Allocate again - should get the same slot back (LIFO)
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
}
}
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Simulate sharing slot with seq 2 (copy-on-write style)
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Should share the same slot
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Ref count should be 2
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share with seq 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Unshare seq 2
cache.refCount[slot1]--
delete(cache.slotForSeq, 2)
// Ref count should be back to 1
if cache.refCount[slot1] != 1 {
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
}
// Seq 2 should no longer have a slot
if _, ok := cache.slotForSeq[2]; ok {
t.Error("seq 2 should not have a slot after unshare")
}
}
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
cache := createSlotOnlyCache(4)
initialFreeSlots := len(cache.freeSlots)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot when refCount drops to 0
cache.refCount[slot1]--
if cache.refCount[slot1] <= 0 {
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
}
delete(cache.slotForSeq, 1)
// Slot should be freed
if len(cache.freeSlots) != initialFreeSlots {
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
}
// Ref count should be 0
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotOverwrite(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slots for seq 1 and seq 2
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
slot2, _ := cache.allocSlot()
cache.slotForSeq[2] = slot2
cache.refCount[slot2] = 1
initialFreeSlots := len(cache.freeSlots)
// Simulate overwriting seq 2's slot with slot1 (sharing)
// First free the old slot
cache.refCount[slot2]--
if cache.refCount[slot2] <= 0 {
cache.refCount[slot2] = 0
cache.freeSlot(slot2)
}
// Then share slot1
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Seq 2 should now share slot1
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Old slot2 should be freed
if len(cache.freeSlots) != initialFreeSlots+1 {
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
}
}
func TestHybridCache_BoundsChecking(t *testing.T) {
cache := createSlotOnlyCache(4)
// Test freeing invalid slot (should not panic)
cache.freeSlot(-1)
cache.freeSlot(100) // out of bounds
// freeSlot does bounds checking, so invalid slots should be ignored
if len(cache.freeSlots) != 4 {
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
}
}
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
cache := createSlotOnlyCache(8)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Fork to seq 2, 3, 4 (all share slot1)
for _, seq := range []int{2, 3, 4} {
cache.slotForSeq[seq] = slot1
cache.refCount[slot1]++
}
// Ref count should be 4
if cache.refCount[slot1] != 4 {
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
}
// Remove seq 2, 3
for _, seq := range []int{2, 3} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
// Slot should still be allocated (not in free list)
found := false
for _, s := range cache.freeSlots {
if s == slot1 {
found = true
break
}
}
if found {
t.Error("slot1 should not be in free list yet")
}
// Remove remaining sequences
for _, seq := range []int{1, 4} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_ChainedSharing(t *testing.T) {
cache := createSlotOnlyCache(8)
// Create seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share 1 -> 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Share 2 -> 3 (should still share slot1)
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
cache.refCount[slot1]++
// All should share slot1
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
t.Error("all sequences should share slot1")
}
if cache.refCount[slot1] != 3 {
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_CacheParameters(t *testing.T) {
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
if cache.hiddenSize != 512 {
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
}
if cache.dConv != 5 {
t.Errorf("expected dConv 5, got %d", cache.dConv)
}
}
func TestHybridCache_NumSeqs(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially no sequences
if cache.numSeqs() != 0 {
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
}
// Manually set up current batch state
cache.curSeqs = []int{1, 2, 3}
if cache.numSeqs() != 3 {
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
}
}
func TestHybridCache_SeqTokens(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially 0
if cache.seqTokens() != 0 {
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
}
// Manually set up current batch state
cache.curSeqTokens = 16
if cache.seqTokens() != 16 {
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
}
}
// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
cache := createSlotOnlyCache(4)
cache.curSeqs = []int{1, 2, 3}
seqs := cache.Seqs()
// Modify returned slice
seqs[0] = 999
// Original should be unchanged
if cache.curSeqs[0] != 1 {
t.Error("Seqs should return a clone, not the original slice")
}
}
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially not supported (no batch set up)
if cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be false initially")
}
// Set up a valid batch
cache.curSeqTokens = 1
cache.curSeqs = []int{1}
if !cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be true with valid batch")
}
}
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
cache := createSlotOnlyCache(4)
// zeroConvSlots should handle empty slots without panicking
cache.zeroConvSlots(nil, nil)
cache.zeroConvSlots(nil, []int{})
// zeroConvSlots should handle empty convStates without panicking
cache.zeroConvSlots(nil, []int{0, 1, 2})
}
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot (simulating sequence removal)
cache.refCount[slot1]--
cache.freeSlot(slot1)
delete(cache.slotForSeq, 1)
// Verify slot is in free list
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
}
// Allocate for new seq 2 - should get recycled slot
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
}
// This recycled slot would need zeroing in the real implementation
// The actual zeroing is tested via integration tests since it requires ML context
}
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
cache := createSlotOnlyCache(4)
// Simulate the slot allocation flow from StartForward
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
newSlots := []int{}
// Seq 1 doesn't have a slot - allocate and track
seq := 1
if _, ok := cache.slotForSeq[seq]; !ok {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
// Verify newSlots contains the allocated slot
if len(newSlots) != 1 {
t.Errorf("expected 1 new slot, got %d", len(newSlots))
}
// Seq 1 already has a slot - should NOT be tracked as new
newSlots2 := []int{}
if _, ok := cache.slotForSeq[seq]; !ok {
slot, _ := cache.allocSlot()
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots2 = append(newSlots2, slot)
}
// Verify no new slots for existing sequence
if len(newSlots2) != 0 {
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
}
}

253
model/models/lfm2/model.go Normal file
View File

@@ -0,0 +1,253 @@
package lfm2
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeType string
originalContextLength int
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
}
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
if h > 0 {
return cmp.Or(o.headDim, o.hiddenSize/h)
}
}
return cmp.Or(o.headDim, o.hiddenSize)
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
headCount := 1
for _, h := range o.numHeadsByLayer {
if h > 0 {
headCount = h
break
}
}
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
return nil, model.ErrUnsupportedModel
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "default":
// use default BPE pretokenizer
default:
// llama-bpe style (default for LFM2)
pretokenizers = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
hc, ok := c.(headCounts)
if !ok {
return nil, model.ErrUnsupportedModel
}
headCount := hc.HeadCount()
headCountKV := hc.HeadCountKV()
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
if m.numKVHeadsByLayer[i] == 0 {
m.Layers[i].Operator = &ShortConv{}
} else {
m.Layers[i].Operator = &Attention{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
dConv := max(0, lCache-1)
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
return &m, nil
}
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := opts.headDimValue()
numHeads := opts.numHeadsByLayer[layer]
numKVHeads := opts.numKVHeadsByLayer[layer]
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, numHeads, batchSize)
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("lfm2", New)
}

View File

@@ -0,0 +1,50 @@
package lfm2
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type shortConvKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
Conv *shortConvKernel `gguf:"shortconv.conv"`
InProj *nn.Linear `gguf:"shortconv.in_proj"`
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
nSeqs := cache.numSeqs()
seqTokens := cache.seqTokens()
hiddenSize := hiddenStates.Dim(0)
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
panic("lfm2: unsupported batch layout for shortconv")
}
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
elementSize := bcx.Stride(0)
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
state, err := cache.ConvState(ctx, layer)
if err != nil {
panic("lfm2: failed to get conv state: " + err.Error())
}
sx := state.Concat(ctx, bx, 0)
convOut := sx.SSMConv(ctx, sc.Conv.Weight)
y := c.Mul(ctx, convOut)
dConv := sx.Dim(0) - seqTokens
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}

View File

@@ -7,7 +7,9 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

410
model/parsers/glm46.go Normal file
View File

@@ -0,0 +1,410 @@
package parsers
import (
"context"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type glm46ParserState int
const (
glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
glm46ParserState_ThinkingStartedEatingWhitespace
glm46ParserState_CollectingThinking
glm46ParserState_ThinkingDoneEatingWhitespace
glm46ParserState_CollectingContent
glm46ParserState_ToolStartedEatingWhitespace
glm46ParserState_CollectingToolContent
)
const (
glm46ThinkingOpenTag = "<think>"
glm46ThinkingCloseTag = "</think>"
glm46ToolOpenTag = "<tool_call>"
glm46ToolCloseTag = "</tool_call>"
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}
type glm46Event interface {
isGLM46Event()
}
type glm46EventContent struct {
content string
}
func (glm46EventContent) isGLM46Event() {}
type glm46EventRawToolCall struct {
raw string
}
func (glm46EventRawToolCall) isGLM46Event() {}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseGLM46ToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46ParserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = glm46ParserState_ThinkingStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = glm46ParserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case glm46ParserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)
case glm46ParserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, glm46ThinkingCloseTag) {
thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, glm46EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = glm46ParserState_ThinkingDoneEatingWhitespace
} else {
p.state = glm46ParserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)
case glm46ParserState_CollectingContent:
if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
if after == "" {
p.state = glm46ParserState_ToolStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)
case glm46ParserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, glm46ToolCloseTag) {
toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm46 tool call closing tag found but no content before it")
}
events = append(events, glm46EventRawToolCall{raw: toolContent})
p.state = glm46ParserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
// We need to escape text between tags, but not the tags themselves
escaped := escapeGLM46Content(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}

862
model/parsers/glm46_test.go Normal file
View File

@@ -0,0 +1,862 @@
package parsers
import (
"encoding/xml"
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM46ParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []glm46Event
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "leading whitespace before think tag",
steps: []step{
{
input: " \n\t ",
wantEvents: []glm46Event{},
},
{
input: "<think>thinking</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
},
},
},
{
desc: "think tag with whitespace inside",
steps: []step{
{
input: "<think> \n thinking content \n </think>regular content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
glm46EventContent{content: "regular content"},
},
},
},
},
{
desc: "tool call with leading whitespace after opening tag",
steps: []step{
{
input: "<think></think><tool_call> \n test \n </tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "simple thinking then content",
steps: []step{
{
input: "<think>I am thinking</think>Now I respond",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "I am thinking"},
glm46EventContent{content: "Now I respond"},
},
},
},
},
{
desc: "streamed thinking content",
steps: []step{
{
input: "<think>hello",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
},
{
input: " world",
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "<think>Let me call a tool</think>here is text<tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "Let me call a tool"},
glm46EventContent{content: "here is text"},
},
},
{
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
},
},
},
},
{
desc: "tool call with content after",
steps: []step{
{
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after tool"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call is trimmed",
steps: []step{
{
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content is trimmed",
steps: []step{
{
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking close tag",
steps: []step{
{
input: "<think>thinking content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
},
{
input: "ink>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking open tag",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nk>content</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
},
},
{
desc: "split tool open tag",
steps: []step{
{
input: "<think>think</think>content<tool",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
},
{
input: "_call>inside",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "inside"},
},
},
},
},
{
desc: "partial thinking close tag fakeout",
steps: []step{
{
input: "<think>content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
{
input: "ought more",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
},
},
},
{
desc: "partial thinking open tag fakeout",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nking is fun",
wantEvents: []glm46Event{
glm46EventContent{content: " <thinking is fun"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "<think></think>content\n<tool",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
{
input: " fakeout",
wantEvents: []glm46Event{
glm46EventContent{content: "\n<tool fakeout"},
},
},
},
},
{
desc: "partial tool close tag fakeout",
steps: []step{
{
input: "<think></think><tool_call>content</tool",
wantEvents: []glm46Event{},
},
{
input: " fakeout",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "content</tool fakeout"},
},
},
},
},
{
desc: "empty thinking tag",
steps: []step{
{
input: "<think></think>content here",
wantEvents: []glm46Event{
glm46EventContent{content: "content here"},
},
},
},
},
{
desc: "multiple tool calls in sequence",
steps: []step{
{
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "first"},
glm46EventContent{content: "between"},
glm46EventRawToolCall{raw: "second"},
glm46EventContent{content: "end"},
},
},
},
},
{
desc: "no thinking tag - direct to content",
steps: []step{
{
input: "just content here",
wantEvents: []glm46Event{
glm46EventContent{content: "just content here"},
},
},
},
},
{
desc: "no thinking tag - skip to content then tool call",
steps: []step{
{
input: "Here's the answer:<tool_call>test</tool_call>done",
wantEvents: []glm46Event{
glm46EventContent{content: "Here's the answer:"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "done"},
},
},
},
},
{
desc: "no thinking tag - whitespace preserved when no tags",
steps: []step{
{
input: " \n content with leading whitespace",
wantEvents: []glm46Event{
glm46EventContent{content: " \n content with leading whitespace"},
},
},
},
},
{
desc: "whitespace after think close tag gets eaten",
steps: []step{
{
input: "<think>thinking</think> \n\t content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "whitespace after tool_call close tag gets eaten",
steps: []step{
{
input: "<think></think><tool_call>test</tool_call> \n\t content",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace (single chunk)",
steps: []step{
{
input: "<think>thinking content ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
},
},
{
input: "</think>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace with newlines",
steps: []step{
{
input: "<think>thinking\n\n ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content trailing whitespace emitted when more content arrives",
steps: []step{
{
input: "<think>thinking ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "more thinking",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: " more thinking"},
},
},
{
input: "</think>",
wantEvents: []glm46Event{},
},
},
},
{
desc: "thinking content withholds trailing whitespace before partial close tag",
steps: []step{
{
input: "<think>thinking </th",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "ink>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := GLM46Parser{}
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
testCases := []struct {
name string
xml string
wantKeys []string
wantValues []string
}{
{
name: "alternating keys and values",
xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
wantKeys: []string{"first", "second", "third"},
wantValues: []string{"A", "B", "C"},
},
{
name: "all keys then all values",
xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
wantKeys: []string{"key1", "key2", "key3"},
wantValues: []string{"val1", "val2", "val3"},
},
{
name: "mixed grouping",
xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
wantKeys: []string{"a", "b", "c"},
wantValues: []string{"1", "2", "3"},
},
{
name: "reverse order - all values then all keys",
xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
wantKeys: []string{"x", "y", "z"},
wantValues: []string{"X", "Y", "Z"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var parsed GLMToolCallXML
err := xml.Unmarshal([]byte(tc.xml), &parsed)
if err != nil {
t.Fatalf("failed to unmarshal XML: %v", err)
}
if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
}
if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
}
})
}
}
func TestGLM46ToolCallParsing(t *testing.T) {
type testCase struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
cases := []testCase{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
},
},
},
{
name: "function name with whitespace",
tools: []api.Tool{},
rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris"}`),
},
},
},
{
name: "values with special characters",
tools: []api.Tool{},
rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute-command",
Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
},
},
},
{
name: "unicode in function names and values",
tools: []api.Tool{},
rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
},
},
},
{
name: "empty value",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param1": ""}`),
},
},
},
{
name: "special chars in arg_key names",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
},
},
},
{
name: "multiple consecutive ampersands",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "test &&&& more"}`),
},
},
},
{
name: "mixed special chars together",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "<>&<>&"}`),
},
},
},
{
name: "newlines and tabs in parameter values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
indented line2
line3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
},
},
},
{
name: "single and double quotes in values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
},
},
},
{
name: "CDATA-like content that should be treated as text",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
},
},
},
{
name: "all special XML entities",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value>&lt;&gt;&amp;&apos;&quot;</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"entities": "&lt;&gt;&amp;&apos;&quot;"}`),
},
},
},
{
name: "order preservation with multiple parameters",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
},
},
},
{
name: "order preservation with identical key names but different positions",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
// Later occurrence should overwrite earlier one
Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
},
},
},
{
name: "array with mixed types",
tools: []api.Tool{
tool("process", map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: args(`{"items": [1, "hello", true, null]}`),
},
},
},
{
name: "empty array",
tools: []api.Tool{
tool("test", map[string]api.ToolProperty{
"tags": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test",
Arguments: args(`{"tags": []}`),
},
},
},
{
name: "anyOf array or string - with array of objects",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
}),
},
// <tool_call>TodoWrite
// <arg_key>todos</arg_key>
// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
// </tool_call>
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
},
},
},
{
name: "anyOf array or string - with plain string",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {Type: api.PropertyType{"array", "string"}},
}),
},
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": "Error: could not load todos"}`),
},
},
},
}
for i, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
if err != nil {
t.Errorf("case %d (%s): %v", i, tc.name, err)
}
if !toolCallEqual(gotToolCall, tc.wantToolCall) {
t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
}
})
}
}

20
model/parsers/glm47.go Normal file
View File

@@ -0,0 +1,20 @@
package parsers
import "github.com/ollama/ollama/api"
// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
GLM46Parser
}
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = glm46ParserState_CollectingThinking
}
return tools
}

View File

@@ -0,0 +1,99 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM47ParserAdd(t *testing.T) {
parser := GLM47Parser{}
parser.Init([]api.Tool{
tool("calculate", map[string]api.ToolProperty{
"count": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
}),
}, nil, nil)
// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
// so the model output does NOT include the opening <think> tag.
content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "plan" {
t.Fatalf("expected thinking 'plan', got %q", thinking)
}
if content != "Answer" {
t.Fatalf("expected content 'Answer', got %q", content)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
expectedArgs := args(`{"count": 3, "enabled": true}`)
if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
}
}
func TestGLM47ParserNoThinkingContent(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
// When thinking is enabled but model has no thinking to output,
// it should output </think> immediately followed by content.
content, thinking, calls, err := parser.Add("</think>Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserThinkingDisabled(t *testing.T) {
parser := GLM47Parser{}
// When thinking is disabled, parser stays in LookingForThinkingOpen state
parser.Init(nil, nil, &api.ThinkValue{Value: false})
// Model outputs plain content (prompt ended with </think>)
content, thinking, calls, err := parser.Add("Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserToolCallEscaping(t *testing.T) {
toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
expected := api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: args(`{"expr": "a < b && c > d"}`),
},
}
if !reflect.DeepEqual(toolCall, expected) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}

498
model/parsers/lfm2.go Normal file
View File

@@ -0,0 +1,498 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type LFM2ParserState int
const (
LFM2CollectingThinking LFM2ParserState = iota
LFM2CollectingContent
LFM2CollectingToolCalls
)
const (
lfm2ThinkingOpenTag = "<think>"
lfm2ThinkingCloseTag = "</think>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
)
type LFM2Parser struct {
state LFM2ParserState
buffer strings.Builder
hasThinkingSupport bool
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
}
func (p *LFM2Parser) HasToolSupport() bool {
return true
}
func (p *LFM2Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !thinkingEnabled {
p.state = LFM2CollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = LFM2CollectingContent
return
}
p.state = LFM2CollectingThinking
p.needsThinkingLeadingTrim = true
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, thinkValue)
return tools
}
type lfm2Event interface {
isLFM2Event()
}
type lfm2EventThinkingContent struct {
content string
}
type lfm2EventContent struct {
content string
}
type lfm2EventToolCall struct {
toolCall api.ToolCall
}
func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event() {}
func (lfm2EventToolCall) isLFM2Event() {}
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case lfm2EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case lfm2EventThinkingContent:
thinkingSb.WriteString(event.content)
case lfm2EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
keepLooping := true
for keepLooping {
var events []lfm2Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
var events []lfm2Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case LFM2CollectingThinking:
// Strip opening <think> tag if present
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
p.needsThinkingLeadingTrim = true
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Trim leading whitespace after <think> tag (may span multiple chunks)
if p.needsThinkingLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once we have non-whitespace content or buffer is empty
if len(bufStr) > 0 {
p.needsThinkingLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingContent
p.needsThinkingLeadingTrim = false
// Set flag to trim any additional whitespace that may arrive in later chunks
p.needsContentLeadingTrim = len(remaining) == 0
if len(thinking) > 0 {
events = append(events, lfm2EventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
}
case LFM2CollectingContent:
// Trim leading whitespace after </think> tag (may span multiple chunks)
if p.needsContentLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once we have non-whitespace content
if len(bufStr) > 0 {
p.needsContentLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, lfm2EventContent{content: contentBefore})
}
return events, true
} else { // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, lfm2EventContent{content: bufStr})
}
return events, false
}
case LFM2CollectingToolCalls:
// Look for complete tool call JSON between tags
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
toolCallContent := bufStr[:idx]
if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
// Check if there's another tool call
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
remaining = remaining[len(lfm2ToolCallStartTag):]
} else {
// No more tool calls, go back to content
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.state = LFM2CollectingContent
}
p.buffer.Reset()
p.buffer.WriteString(remaining)
for _, tc := range toolCalls {
events = append(events, lfm2EventToolCall{toolCall: tc})
}
return events, true
} else if err != nil {
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
}
}
return events, false
}
return events, false
}
// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return nil, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return []api.ToolCall{{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}}, nil
}
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCalls(content)
}
// parsePythonStyleToolCalls parses one or more Python-style tool calls
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Strip outer brackets if present: [func(...)] -> func(...)
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
content = content[1 : len(content)-1]
}
var toolCalls []api.ToolCall
// Parse multiple function calls separated by commas at the top level
for len(content) > 0 {
content = strings.TrimSpace(content)
if content == "" {
break
}
// Skip leading comma from previous iteration
if strings.HasPrefix(content, ",") {
content = strings.TrimSpace(content[1:])
if content == "" {
break
}
}
// Find function name
parenIdx := strings.Index(content, "(")
if parenIdx == -1 {
return nil, errors.New("invalid tool call: no opening parenthesis")
}
funcName := strings.TrimSpace(content[:parenIdx])
if funcName == "" {
return nil, errors.New("invalid tool call: empty function name")
}
// Find matching closing parenthesis
closeIdx := findMatchingParen(content, parenIdx)
if closeIdx == -1 {
return nil, errors.New("invalid tool call: no matching closing parenthesis")
}
argsStr := content[parenIdx+1 : closeIdx]
args := api.NewToolCallFunctionArguments()
if argsStr != "" {
if err := parsePythonArgs(argsStr, &args); err != nil {
return nil, err
}
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
})
// Move past this function call
content = content[closeIdx+1:]
}
if len(toolCalls) == 0 {
return nil, errors.New("no tool calls found")
}
return toolCalls, nil
}
// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx
// Returns -1 if not found. Handles nested parentheses and quoted strings.
func findMatchingParen(s string, openIdx int) int {
depth := 1
i := openIdx + 1
for i < len(s) && depth > 0 {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 {
return i
}
case '\'', '"':
// Skip quoted string
quote := s[i]
i++
for i < len(s) && s[i] != quote {
if s[i] == '\\' && i+1 < len(s) {
i++ // skip escaped char
}
i++
}
}
i++
}
return -1
}
// parseToolCallContent parses a single tool call (for backward compatibility with tests)
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
calls, err := p.parseToolCallsContent(content)
if err != nil {
return api.ToolCall{}, err
}
if len(calls) == 0 {
return api.ToolCall{}, errors.New("no tool call found")
}
return calls[0], nil
}
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
}
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
}
return nil
}

1088
model/parsers/lfm2_test.go Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -68,6 +68,12 @@ func ParserForName(name string) Parser {
return &Nemotron3NanoParser{}
case "functiongemma":
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}

View File

@@ -96,3 +96,11 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}

110
model/renderers/glm46.go Normal file
View File

@@ -0,0 +1,110 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GLM46Renderer struct{}
func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>\n")
}
return sb.String(), nil
}

View File

@@ -0,0 +1,223 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
skip string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip != "" {
t.Skip(tt.skip)
}
renderer := &GLM46Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

170
model/renderers/glm47.go Normal file
View File

@@ -0,0 +1,170 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type GLM47Renderer struct{}
func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
func formatGLM47ToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,191 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM47Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
},
{
name: "thinking disabled",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "[gMASK]<sop><|user|>Hello<|assistant|></think>",
},
{
name: "system and user",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello there"},
{Role: "user", Content: "How are you?"},
},
expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
},
{
name: "assistant with reasoning_content",
messages: []api.Message{
{Role: "user", Content: "Answer with reasoning."},
{Role: "assistant", Thinking: "Plan.", Content: "Done."},
},
expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
},
{
name: "tool call with empty content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
},
{
name: "tool call with content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Let me check",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "assistant", Content: "It is 22C."},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
},
{
name: "multiple tool calls and responses",
messages: []api.Message{
{Role: "user", Content: "Compare weather"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Paris"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "tool", Content: `{"temperature":18}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
},
{
name: "preserved thinking in multi-turn",
messages: []api.Message{
{Role: "user", Content: "Think step by step"},
{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
{Role: "user", Content: "Continue"},
},
expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &GLM47Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

144
model/renderers/lfm2.go Normal file
View File

@@ -0,0 +1,144 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type LFM2Renderer struct {
IsThinking bool
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
startIdx = 1
}
// Append tools to first system content
if len(tools) > 0 {
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
}
// Output first system block if it has content
if firstSystemContent != "" {
sb.WriteString("<|im_start|>system\n")
sb.WriteString(firstSystemContent)
sb.WriteString("<|im_end|>\n")
}
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
if messages[i].Role == "assistant" {
lastAssistantIndex = i
break
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
} else {
sb.WriteString(content)
}
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil
}

View File

@@ -0,0 +1,427 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "assistant with content and tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -80,6 +80,12 @@ func rendererForName(name string) Renderer {
return &Nemotron3NanoRenderer{}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
default:
return nil
}

View File

@@ -1,6 +1,26 @@
package renderers
import "github.com/ollama/ollama/api"
import (
"encoding/json"
"github.com/ollama/ollama/api"
)
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}
func propsMap(s string) *api.ToolPropertiesMap {
var result api.ToolPropertiesMap
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in propsMap(): " + err.Error())
}
return &result
}
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -90,7 +91,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}
oldManifest, _ := ParseNamedManifest(name)
oldManifest, _ := manifest.ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
@@ -123,9 +124,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.GetBlobsPath(mf.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -342,7 +343,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := GetBlobsPath(files[fn])
blobPath, err := manifest.GetBlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -394,7 +395,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -432,7 +433,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
layer, err := NewLayer(t, mediaType)
layer, err := manifest.NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -465,7 +466,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
var layers []manifest.Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -550,13 +551,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
for _, layer := range layers {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, *configLayer, layers); err != nil {
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
return err
}
@@ -577,7 +578,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
blob, err := GetBlobsPath(layer.Digest)
blob, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -599,7 +600,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := NewLayer(temp, layer.MediaType)
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -619,7 +620,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML
fn(api.ProgressResponse{Status: "parsing GGUF"})
blobPath, err := GetBlobsPath(digest)
blobPath, err := manifest.GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -654,7 +655,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -665,8 +666,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}
func removeLayer(layers []Layer, mediatype string) []Layer {
return slices.DeleteFunc(layers, func(layer Layer) bool {
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -680,7 +681,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
})
}
func setTemplate(layers []Layer, t string) ([]Layer, error) {
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -690,7 +691,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
}
blob := strings.NewReader(t)
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -699,11 +700,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
return layers, nil
}
func setSystem(layers []Layer, s string) ([]Layer, error) {
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -712,9 +713,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
return layers, nil
}
func setLicense(layers []Layer, l string) ([]Layer, error) {
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
blob := strings.NewReader(l)
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -722,7 +723,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
return layers, nil
}
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -731,7 +732,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
continue
}
digestPath, err := GetBlobsPath(layer.Digest)
digestPath, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -765,7 +766,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -773,7 +774,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
return layers, nil
}
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -786,7 +787,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -794,7 +795,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -805,7 +806,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)
func TestConvertFromSafetensors(t *testing.T) {
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}

View File

@@ -24,6 +24,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -456,7 +458,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
type downloadOpts struct {
mp ModelPath
n model.Name
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -465,10 +467,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
}
fp, err := GetBlobsPath(opts.digest)
fp, err := manifest.GetBlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -492,8 +494,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
requestURL := opts.n.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -24,6 +23,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -274,44 +274,22 @@ func (m *Model) String() string {
return modelfile.String()
}
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}
f, err := os.Open(fp)
if err != nil {
return nil, "", err
}
defer f.Close()
sha256sum := sha256.New()
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
n := model.ParseName(name)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
m := &Model{
Name: n.String(),
ShortName: n.DisplayShortest(),
Digest: mf.Digest(),
Template: template.DefaultTemplate,
}
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if mf.Config.Digest != "" {
filename, err := manifest.GetBlobsPath(mf.Config.Digest)
if err != nil {
return nil, err
}
@@ -322,29 +300,29 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
return nil, err
}
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
for _, layer := range mf.Layers {
filename, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.ParentModel = layer.From
m.ModelPath = filename
m.ParentModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
m.AdapterPaths = append(m.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
m.ProjectorPaths = append(m.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -352,7 +330,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.Template, err = template.Parse(string(bts))
m.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -362,7 +340,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
model.System = string(bts)
m.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -371,7 +349,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -381,7 +359,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -389,11 +367,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
m.License = append(m.License, string(bts))
}
}
return model, nil
return m, nil
}
func CopyModel(src, dst model.Name) error {
@@ -408,7 +386,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}
manifests, err := GetManifestPath()
manifests, err := manifest.GetManifestPath()
if err != nil {
return err
}
@@ -437,7 +415,7 @@ func CopyModel(src, dst model.Name) error {
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := Manifests(true)
manifests, err := manifest.Manifests(true)
if err != nil {
return err
}
@@ -452,7 +430,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
fp, err := manifest.GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -468,7 +446,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
p, err := manifest.GetBlobsPath("")
if err != nil {
return err
}
@@ -483,9 +461,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
_, err := manifest.GetBlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -510,63 +488,30 @@ func PruneLayers() error {
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
manifest, _, err := GetManifest(mp)
mf, err := manifest.ParseNamedManifest(n)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
manifestPath, err := manifest.GetManifestPathForName(n)
if err != nil {
return err
}
@@ -574,7 +519,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -582,17 +527,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
@@ -611,44 +556,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
n := model.ParseName(name)
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
manifest, _, err := GetManifest(mp)
existingMf, err := manifest.ParseNamedManifest(n)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
for _, l := range manifest.Layers {
for _, l := range existingMf.Layers {
deleteMap[l.Digest] = struct{}{}
}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
}
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err = pullModelManifest(ctx, mp, regOpts)
mf, err := pullModelManifest(ctx, n, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -658,7 +603,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
n: n,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -677,7 +622,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
fp, err := GetBlobsPath(layer.Digest)
fp, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -692,16 +637,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
delete(deleteMap, mf.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.GetManifestPathForName(n)
if err != nil {
return err
}
@@ -728,9 +673,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
func hasTensorLayers(layers []manifest.Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
if layer.MediaType == manifest.MediaTypeImageTensor {
return true
}
}
@@ -738,7 +683,7 @@ func hasTensorLayers(layers []Layer) bool {
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -747,12 +692,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
destDir, err := GetBlobsPath("")
destDir, err := manifest.GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -784,7 +729,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Repository: n.DisplayNamespaceModel(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -795,12 +740,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
fp, err := manifest.GetManifestPathForName(n)
if err != nil {
return err
}
@@ -812,7 +757,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -822,12 +767,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
srcDir, err := GetBlobsPath("")
srcDir, err := manifest.GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -864,13 +809,13 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
ManifestRef: n.Tag,
Repository: n.DisplayNamespaceModel(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -880,7 +825,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m Manifest
var m manifest.Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -1042,7 +987,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
fp, err := manifest.GetBlobsPath(digest)
if err != nil {
return err
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -20,19 +21,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
Layer
manifest.Layer
*ggml.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := ParseNamedManifest(name)
m, err := manifest.ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
m, err = ParseNamedManifest(name)
m, err = manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest)
blobpath, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})
if t.Parameters != nil {
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}

View File

@@ -1,146 +0,0 @@
package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrModelPathInvalid = errors.New("invalid model path")
)
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}

View File

@@ -1,153 +0,0 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
tempDir := t.TempDir()
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", tempDir)
got, err := GetBlobsPath(tc.digest)
require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
name := t.Name
quantize := strings.HasSuffix(name, "weight")
// don't quantize vision stuff
quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
// don't quantize vision encoder tensors (named with "v." prefix)
quantize = quantize && !strings.HasPrefix(name, "v.")
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
@@ -219,6 +219,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// NOTE: can't use LLM_TN here because the layer number is not known
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
// do not quantize LFM2's shortconv kernel weights
quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
// do not quantize RWKV's time_mix_first tensors
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")

View File

@@ -39,6 +39,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
@@ -220,12 +221,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
@@ -321,7 +316,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// expire the runner
// expire the runner if unload is requested (empty prompt, keep alive is 0)
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -335,6 +330,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
@@ -974,7 +975,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
existing, err := manifest.Manifests(true)
if err != nil {
return zero, err
}
@@ -1018,7 +1019,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
m, err := ParseNamedManifest(n)
m, err := manifest.ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1080,7 +1081,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, ErrModelPathInvalid
return nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
@@ -1112,7 +1113,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
@@ -1121,7 +1122,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
@@ -1135,7 +1136,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
manifest, err := ParseNamedManifest(name)
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -1147,8 +1148,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
ModifiedAt: mf.FileInfo().ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
ModelInfo: make(map[string]any),
}
if m.Config.RemoteHost != "" {
@@ -1211,7 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1220,12 +1224,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1282,7 +1286,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}
func (s *Server) ListHandler(c *gin.Context) {
ms, err := Manifests(true)
ms, err := manifest.Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1313,8 +1317,8 @@ func (s *Server) ListHandler(c *gin.Context) {
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Digest: m.Digest(),
ModifiedAt: m.FileInfo().ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1373,7 +1377,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1389,7 +1393,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
p, err := GetBlobsPath(ib)
p, err := manifest.GetBlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1407,7 +1411,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
path, err := GetBlobsPath(c.Param("digest"))
path, err := manifest.GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1425,7 +1429,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
layer, err := NewLayer(c.Request.Body, "")
layer, err := manifest.NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1625,7 +1629,7 @@ func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
blobsDir, err := GetBlobsPath("")
blobsDir, err := manifest.GetBlobsPath("")
if err != nil {
return err
}
@@ -1634,7 +1638,7 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
if _, err := Manifests(false); err != nil {
if _, err := manifest.Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
@@ -1642,12 +1646,12 @@ func Serve(ln net.Listener) error {
return err
}
manifestsPath, err := GetManifestPath()
manifestsPath, err := manifest.GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
}
}

View File

@@ -25,6 +25,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
manifest, err := ParseNamedManifest(model.ParseName("child"))
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if manifest.Config.Digest == "" {
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
configPath, err := GetBlobsPath(manifest.Config.Digest)
configPath, err := manifest.GetBlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err)
}
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
if err := WriteManifest(n, config, []Layer{config}); err != nil {
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
t.Fatal(err)
}

View File

@@ -2101,3 +2101,95 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
})
}
func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)
var loadFnCalled bool
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
loadFnCalled = false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "",
KeepAlive: &api.Duration{Duration: 0},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.DoneReason != "unload" {
t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
}
if !resp.Done {
t.Error("expected done to be true")
}
if loadFnCalled {
t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
}
})
}

View File

@@ -571,6 +571,7 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),

View File

@@ -21,12 +21,14 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
type blobUpload struct {
Layer
manifest.Layer
Total int64
Completed atomic.Int64
@@ -51,7 +53,7 @@ const (
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := GetBlobsPath(b.Digest)
p, err := manifest.GetBlobsPath(b.Digest)
if err != nil {
return err
}
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
if b.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
requestURL.RawQuery = values.Encode()
}
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
p, err := GetBlobsPath(b.Digest)
p, err := manifest.GetBlobsPath(b.Digest)
if err != nil {
b.err = err
return
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"path/filepath"
"strings"
)
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"
const (
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultProtocolScheme = "https"
)
// DefaultName returns a name with the default values for the host, namespace,
// and tag parts. The model and digest parts are empty.
// tag, and protocol scheme parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
// - The default protocol scheme is ("https")
func DefaultName() Name {
return Name{
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
ProtocolScheme: defaultProtocolScheme,
}
}
@@ -87,10 +91,11 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
Host string
Namespace string
Model string
Tag string
Host string
Namespace string
Model string
Tag string
ProtocolScheme string
}
// ParseName parses and assembles a Name from a name string. The
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
}
scheme, host, ok := strings.Cut(s, "://")
if !ok {
if ok {
n.ProtocolScheme = scheme
} else {
host = scheme
}
n.Host = host
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
// Merge merges the host, namespace, and tag parts of the two names,
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
return a
}
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
strings.EqualFold(n.Tag, o.Tag)
}
// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
return &url.URL{
Scheme: n.ProtocolScheme,
Host: n.Host,
}
}
// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
var b strings.Builder
b.WriteString(n.Namespace)
b.WriteByte('/')
b.WriteString(n.Model)
return b.String()
}
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:

View File

@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
{
in: "scheme://host:port/namespace/model:tag",
want: Name{
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
ProtocolScheme: "scheme",
},
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
},

View File

@@ -12,8 +12,8 @@ import (
"fmt"
"io"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
layer, err := manifest.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
// Convert LayerInfo to manifest.Layer
manifestLayers := make([]manifest.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
manifestLayers = append(manifestLayers, manifest.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
manifestLayers = append(manifestLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
return manifest.WriteManifest(name, configLayer, manifestLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}

View File

@@ -54,6 +54,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp4":
// affine mode: group_size=32, bits=4
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")

View File

@@ -20,10 +20,10 @@ import (
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp8":
case "", "fp4", "fp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
}
var layers []LayerInfo

View File

@@ -7,12 +7,17 @@ import (
"encoding/json"
"flag"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
@@ -46,8 +51,8 @@ func main() {
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -61,6 +66,7 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
@@ -122,6 +128,44 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *flux2Flag:
m := &flux2.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
// Load input images with EXIF orientation correction
var loadedImages []image.Image
for _, path := range inputImages {
img, loadErr := loadImageWithEXIF(path)
if loadErr != nil {
log.Fatalf("Failed to load image %s: %v", path, loadErr)
}
loadedImages = append(loadedImages, img)
}
// When input images provided and user didn't override dimensions, use 0 to match input
fluxWidth := int32(*width)
fluxHeight := int32(*height)
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
// Both unset, will auto-detect from input
} else if len(loadedImages) > 0 && *width == 0 {
fluxWidth = 0 // Compute from height + aspect ratio
} else if len(loadedImages) > 0 && *height == 0 {
fluxHeight = 0 // Compute from width + aspect ratio
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
Prompt: *prompt,
Width: fluxWidth,
Height: fluxHeight,
Steps: *steps,
GuidanceScale: float32(*cfgScale),
Seed: *seed,
CapturePath: *gpuCapture,
InputImages: loadedImages,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
@@ -276,6 +320,8 @@ func detectModelKind(modelPath string) (string, error) {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
case "Flux2KleinPipeline":
return "flux2", nil
}
}
return "zimage", nil
@@ -296,3 +342,12 @@ func detectModelKind(modelPath string) (string, error) {
return cfg.ModelType, nil
}
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
func loadImageWithEXIF(path string) (image.Image, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file: %w", err)
}
return imagegen.DecodeImage(data)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64"
"fmt"
"image"
_ "image/jpeg"
"image/png"
"os"
"path/filepath"
@@ -108,3 +109,160 @@ func clampF(v, min, max float32) float32 {
}
return v
}
// DecodeImage decodes image bytes with EXIF orientation applied.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
img, _, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, err
}
return applyOrientation(img, orientation), nil
}
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
return 1 // Not JPEG
}
r := bytes.NewReader(data[2:])
for {
var marker [2]byte
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
return 1
}
if marker[1] == 0xE1 { // APP1 (EXIF)
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen < 14 {
r.Seek(int64(segLen), 1)
continue
}
seg := make([]byte, segLen)
if _, err := r.Read(seg); err != nil {
return 1
}
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
return parseTIFFOrientation(seg[6:])
}
continue
}
if marker[1] == 0xD9 || marker[1] == 0xDA {
return 1 // EOI or SOS
}
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
continue // RST markers
}
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen > 0 {
r.Seek(int64(segLen), 1)
}
}
}
func parseTIFFOrientation(tiff []byte) int {
if len(tiff) < 8 {
return 1
}
var big bool
switch string(tiff[:2]) {
case "MM":
big = true
case "II":
big = false
default:
return 1
}
u16 := func(b []byte) uint16 {
if big {
return uint16(b[0])<<8 | uint16(b[1])
}
return uint16(b[1])<<8 | uint16(b[0])
}
u32 := func(b []byte) uint32 {
if big {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
}
if u16(tiff[2:4]) != 42 {
return 1
}
ifdOffset := u32(tiff[4:8])
if int(ifdOffset)+2 > len(tiff) {
return 1
}
numEntries := u16(tiff[ifdOffset : ifdOffset+2])
for i := range int(numEntries) {
offset := ifdOffset + 2 + uint32(i)*12
if int(offset)+12 > len(tiff) {
break
}
if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
o := int(u16(tiff[offset+8 : offset+10]))
if o >= 1 && o <= 8 {
return o
}
return 1
}
}
return 1
}
func applyOrientation(img image.Image, orientation int) image.Image {
if orientation <= 1 || orientation > 8 {
return img
}
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
outW, outH := w, h
if orientation >= 5 {
outW, outH = h, w
}
out := image.NewRGBA(image.Rect(0, 0, outW, outH))
for y := range h {
for x := range w {
var dx, dy int
switch orientation {
case 2:
dx, dy = w-1-x, y
case 3:
dx, dy = w-1-x, h-1-y
case 4:
dx, dy = x, h-1-y
case 5:
dx, dy = y, x
case 6:
dx, dy = h-1-y, x
case 7:
dx, dy = h-1-y, w-1-x
case 8:
dx, dy = y, w-1-x
}
out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
}
}
return out
}

View File

@@ -6,8 +6,9 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/ollama/ollama/envconfig"
)
// ManifestLayer represents a layer in the manifest.
@@ -32,31 +33,15 @@ type ModelManifest struct {
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
switch runtime.GOOS {
case "darwin":
return filepath.Join(home, ".ollama", "models", "blobs")
case "linux":
return filepath.Join(home, ".ollama", "models", "blobs")
case "windows":
return filepath.Join(home, ".ollama", "models", "blobs")
default:
return filepath.Join(home, ".ollama", "models", "blobs")
}
return filepath.Join(envconfig.Models(), "blobs")
}
// DefaultManifestDir returns the default manifest storage directory.
// DefaultManifestDir returns the manifest storage directory.
// Respects OLLAMA_MODELS.
func DefaultManifestDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
return filepath.Join(envconfig.Models(), "manifests")
}
// LoadManifest loads a manifest for the given model name.

View File

@@ -0,0 +1,26 @@
package imagegen
import (
"path/filepath"
"testing"
)
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
modelsDir := filepath.Join(t.TempDir(), "models")
// Simulate packaged/systemd environment
t.Setenv("OLLAMA_MODELS", modelsDir)
t.Setenv("HOME", "/usr/share/ollama")
// Manifest dir must respect OLLAMA_MODELS
wantManifest := filepath.Join(modelsDir, "manifests")
if got := DefaultManifestDir(); got != wantManifest {
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
}
// Blob dir must respect OLLAMA_MODELS
wantBlobs := filepath.Join(modelsDir, "blobs")
if got := DefaultBlobDir(); got != wantBlobs {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
}
}

View File

@@ -24,9 +24,8 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 20 * GB, // ~20GB for Flux
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -72,26 +71,38 @@ func ResolveModelName(modelName string) string {
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName)
if err != nil {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
className := DetectModelType(modelName)
if estimate, ok := modelVRAMEstimates[className]; ok {
return estimate
}
return 21 * GB
}
// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return ""
}
var index struct {
Architecture string `json:"architecture"`
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return ""
}
// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
if index.Architecture != "" {
return index.Architecture
}
return index.ClassName
}

View File

@@ -72,9 +72,8 @@ func TestCheckMemoryRequirements(t *testing.T) {
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB,
"QwenImagePipeline": 80 * GB,
"ZImagePipeline": 21 * GB,
"FluxPipeline": 20 * GB,
}
for name, expectedVRAM := range expected {

View File

@@ -1137,6 +1137,27 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
return RMSNorm(x, ones, eps)
}
// LayerNorm applies layer normalization without learnable params
// (x - mean) / sqrt(var + eps)
func LayerNorm(x *Array, eps float32) *Array {
return LayerNormWithWeightBias(x, nil, nil, eps)
}
// LayerNormWithWeightBias computes layer normalization using mlx.fast
// weight and bias can be nil for elementwise_affine=False
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
res := C.mlx_array_new()
var wc, bc C.mlx_array
if weight != nil {
wc = weight.c
}
if bias != nil {
bc = bias.c
}
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
return newArray(res)
}
// RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
res := C.mlx_array_new()

View File

@@ -0,0 +1,539 @@
//go:build mlx
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2
import (
"context"
"encoding/json"
"fmt"
"image"
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"golang.org/x/image/draw"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 4 for Klein)
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
InputImages []image.Image // Reference images for image conditioning (already loaded)
}
// Model represents a FLUX.2 Klein model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *qwen3.TextEncoder
Transformer *Flux2Transformer2DModel
VAE *AutoencoderKLFlux2
SchedulerConfig *SchedulerConfig
}
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
var TextEncoderLayerIndices = []int{8, 17, 26}
// Load loads the FLUX.2 Klein model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &qwen3.TextEncoder{}
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
// Load transformer
m.Transformer = &Flux2Transformer2DModel{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
// Load VAE
m.VAE = &AutoencoderKLFlux2{}
if err := m.VAE.Load(manifest); err != nil {
return fmt.Errorf("VAE: %w", err)
}
// Evaluate all weights in a single batch (reduces GPU sync overhead)
fmt.Print(" Evaluating weights... ")
allWeights := mlx.Collect(m.TextEncoder)
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
mlx.Eval(allWeights...)
fmt.Println("✓")
// Load scheduler config
m.SchedulerConfig = DefaultSchedulerConfig()
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
}
}
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
const MaxRefPixels = 728 * 728
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Enable MLX compilation for fused kernels
mlx.EnableCompile()
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 4 // Klein default: 4 steps for distilled model
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
}
// Determine output dimensions
if len(cfg.InputImages) > 0 {
// With input images, compute missing dimension from aspect ratio
// Images are already EXIF-rotated by the caller
bounds := cfg.InputImages[0].Bounds()
imgW, imgH := bounds.Dx(), bounds.Dy()
aspectRatio := float64(imgH) / float64(imgW)
if cfg.Width > 0 && cfg.Height <= 0 {
// Width specified, compute height
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
} else if cfg.Height > 0 && cfg.Width <= 0 {
// Height specified, compute width
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
} else if cfg.Width <= 0 && cfg.Height <= 0 {
// Neither specified, use input dimensions
cfg.Width = int32(imgW)
cfg.Height = int32(imgH)
}
}
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
pixels := int(cfg.Width) * int(cfg.Height)
if pixels > MaxOutputPixels {
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
}
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
tcfg := m.Transformer.TransformerConfig
patchSize := m.VAE.Config.PatchSize
// Latent dimensions: image / 8 (VAE downscale) / patch_size
latentH := cfg.Height / 8
latentW := cfg.Width / 8
patchH := latentH / patchSize[0]
patchW := latentW / patchSize[1]
imgSeqLen := patchH * patchW
// Text encoding with multi-layer extraction (no padding, use true sequence length)
fmt.Print(" Encoding prompt... ")
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
fmt.Println("✓")
// Encode reference images if provided
var refTokens *ImageCondTokens
var refHeights, refWidths []int32
if len(cfg.InputImages) > 0 {
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
var err error
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
if err != nil {
return nil, fmt.Errorf("encode reference images: %w", err)
}
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
limitPixels := MaxRefPixels
if len(cfg.InputImages) > 1 {
limitPixels = MaxRefPixels / 2
}
for _, img := range cfg.InputImages {
_, w, h := PrepareImage(img, limitPixels)
refHeights = append(refHeights, int32(h/16))
refWidths = append(refWidths, int32(w/16))
}
}
// Scheduler
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
latentChannels := m.VAE.Config.LatentChannels
packedChannels := latentChannels * 4 // 32 * 4 = 128
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
// This matches diffusers' _pack_latents
patches := packLatents(latents)
noiseSeqLen := patches.Shape()[1]
// RoPE cache - includes reference images if present
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
// Cleanup setup arrays when done
defer func() {
rope.Cos.Free()
rope.Sin.Free()
promptEmbeds.Free()
if refTokens != nil {
refTokens.Tokens.Free()
}
}()
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
timesteps := make([]*mlx.Array, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
}
// Evaluate setup arrays
fmt.Print(" Evaluating setup... ")
setupStart := time.Now()
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
toEval = append(toEval, timesteps...)
if refTokens != nil {
toEval = append(toEval, refTokens.Tokens)
}
mlx.Eval(toEval...)
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps)
}
loopStart := time.Now()
stepStart := time.Now()
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
timestep := timesteps[i]
// Prepare input - concatenate noise patches with reference tokens if present
imgInput := patches
if refTokens != nil {
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
}
// Transformer forward pass
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
// If we concatenated reference tokens, slice to only get noise portion
if refTokens != nil {
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
}
// Scheduler step (keep reference to old patches for the computation graph)
newPatches := scheduler.Step(output, patches, i)
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
mlx.Eval(newPatches)
patches = newPatches
elapsed := time.Since(stepStart).Seconds()
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
if i == 0 {
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
} else {
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
}
stepStart = time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
}
loopTime := time.Since(loopStart).Seconds()
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
// Free timesteps now that denoising is done
for _, ts := range timesteps {
ts.Free()
}
// VAE decode with tiling for larger images
fmt.Print(" Decoding VAE... ")
vaeStart := time.Now()
// Enable tiling for images > 512x512 (latent > 64x64)
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
if patchH*2 > 64 || patchW*2 > 64 {
m.VAE.Tiling = DefaultTilingConfig()
}
decoded := m.VAE.Decode(patches, patchH, patchW)
mlx.Eval(decoded)
// Free patches now that decode is done
patches.Free()
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
func packLatents(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
x = mlx.Reshape(x, B, C, H*W)
return mlx.Transpose(x, 0, 2, 1)
}
// LoadPersistent loads the model and keeps it in memory for repeated use.
func LoadPersistent(modelName string) (*Model, error) {
m := &Model{}
if err := m.Load(modelName); err != nil {
return nil, err
}
return m, nil
}
// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
const ImageRefScale = 10
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
// Returns the processed image and its dimensions.
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Cap pixels if needed (like diffusers cap_pixels)
if limitPixels > 0 && w*h > limitPixels {
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
w = int(float64(w) * scale)
h = int(float64(h) * scale)
}
// Round down to multiple of 16
w = (w / 16) * 16
h = (h / 16) * 16
if w < 16 {
w = 16
}
if h < 16 {
h = 16
}
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
resized := image.NewRGBA(image.Rect(0, 0, w, h))
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
return resized, w, h
}
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
func ImageToTensor(img image.Image) *mlx.Array {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
data := make([]float32, 3*h*w)
for y := 0; y < h; y++ {
for x := 0; x < w; x++ {
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
// RGBA returns 16-bit values, convert to [-1, 1]
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
}
}
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
return arr
}
// ImageCondTokens holds encoded reference image tokens.
type ImageCondTokens struct {
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
}
// EncodeImageRefs encodes reference images using the VAE.
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
if len(images) == 0 {
return nil, nil
}
// Limit reference images to reduce attention memory
limitPixels := MaxRefPixels
if len(images) > 1 {
limitPixels = MaxRefPixels / 2
}
var allTokens []*mlx.Array
for _, img := range images {
// Prepare image (resize, crop to multiple of 16)
prepared, prepW, prepH := PrepareImage(img, limitPixels)
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
// Convert to tensor [-1, 1]
tensor := ImageToTensor(prepared)
// Encode with VAE - returns [1, L, 128]
encoded := m.VAE.EncodeImage(tensor)
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
// Defer eval - will be done with other setup arrays
allTokens = append(allTokens, squeezed)
fmt.Println("✓")
}
// For single image, just add batch dimension directly
// For multiple images, concatenate first
var tokens *mlx.Array
if len(allTokens) == 1 {
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
} else {
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
}
return &ImageCondTokens{Tokens: tokens}, nil
}

View File

@@ -0,0 +1,224 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// RoPEConfig holds 4D RoPE configuration for Flux2
type RoPEConfig struct {
Theta int32 // 2000 for Klein
AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
}
// RoPECache holds precomputed RoPE cos/sin values
type RoPECache struct {
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
TextLen int32 // Length of text sequence
ImageLen int32 // Length of image sequence
}
// PrepareTextIDs creates position IDs for text tokens.
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
ids := make([]float32, seqLen*4)
for i := int32(0); i < seqLen; i++ {
idx := i * 4
ids[idx+0] = 0 // T = 0
ids[idx+1] = 0 // H = 0
ids[idx+2] = 0 // W = 0
ids[idx+3] = float32(i) // L = sequence position
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareLatentIDs creates position IDs for image latent tokens.
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
// The latents are in row-major order (H then W).
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
seqLen := height * width
ids := make([]float32, seqLen*4)
idx := 0
for h := int32(0); h < height; h++ {
for w := int32(0); w < width; w++ {
ids[idx*4+0] = 0 // T = 0
ids[idx*4+1] = float32(h) // H = row
ids[idx*4+2] = float32(w) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
// Returns: [total_tokens, 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
// Calculate total tokens
totalTokens := int32(0)
for i := range imageHeights {
totalTokens += imageHeights[i] * imageWidths[i]
}
ids := make([]float32, totalTokens*4)
idx := int32(0)
for imgIdx, h := range imageHeights {
w := imageWidths[imgIdx]
tValue := float32(scale * int32(imgIdx+1))
for hi := int32(0); hi < h; hi++ {
for wi := int32(0); wi < w; wi++ {
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
ids[idx*4+1] = float32(hi) // H = row
ids[idx*4+2] = float32(wi) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
}
return mlx.NewArray(ids, []int32{totalTokens, 4})
}
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
// ids: [L, 4] with (T, H, W, L) coordinates
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
// theta: base frequency (2000 for Klein)
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
shape := ids.Shape()
seqLen := shape[0]
// Compute total head dim (sum of all axes dims)
headDim := int32(0)
for _, d := range axesDims {
headDim += d
}
// Extract each coordinate dimension
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
// Compute frequencies for each axis
logTheta := float32(math.Log(float64(theta)))
cosArrs := make([]*mlx.Array, 4)
sinArrs := make([]*mlx.Array, 4)
positions := []*mlx.Array{posT, posH, posW, posL}
for i, axisDim := range axesDims {
half := axisDim / 2
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
freqs := make([]float32, half)
for j := int32(0); j < half; j++ {
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
}
freqArr := mlx.NewArray(freqs, []int32{1, half})
// Compute pos * freq -> [L, half]
posExpanded := positions[i] // [L, 1]
args := mlx.Mul(posExpanded, freqArr) // [L, half]
// Compute cos and sin for this axis
cosAxis := mlx.Cos(args) // [L, half]
sinAxis := mlx.Sin(args) // [L, half]
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
sinAxis = mlx.ExpandDims(sinAxis, 2)
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
cosArrs[i] = cosAxis
sinArrs[i] = sinAxis
}
// Concatenate all axes: [L, headDim]
cos := mlx.Concatenate(cosArrs, 1)
sin := mlx.Concatenate(sinArrs, 1)
// Reshape to [1, L, 1, headDim] for broadcasting with attention
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
return cos, sin
}
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
// x: [B, L, nheads, head_dim]
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
// Returns: x with RoPE applied
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
// Extract real (index 0) and imag (index 1) parts
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
negXImag := mlx.Neg(xImag)
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
}
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
// textLen: number of text tokens
// noiseH, noiseW: dimensions of the noise latent in patch tokens
// axesDims: [32, 32, 32, 32]
// theta: 2000
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
// scale: time coordinate offset between reference images (e.g., 10)
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
textIDs := PrepareTextIDs(textLen)
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
var allIDs *mlx.Array
imageLen := noiseH * noiseW
if len(refHeights) > 0 {
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
for i := range refHeights {
imageLen += refHeights[i] * refWidths[i]
}
} else {
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
}
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
cos = mlx.ToBFloat16(cos)
sin = mlx.ToBFloat16(sin)
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
}

View File

@@ -0,0 +1,149 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds Flow-Match scheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0 for Klein
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "exponential" or "linear"
}
// DefaultSchedulerConfig returns default config for Klein
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0, // Klein uses 3.0
UseDynamicShifting: true,
TimeShiftType: "exponential",
}
}
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
// Then applies time shift, appends 0.0 at end
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
// linspace(1, 1/num_steps, num_steps)
var sigma float32
if numSteps == 1 {
sigma = 1.0
} else {
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
}
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
sigma = s.timeShift(mu, sigma)
} else {
// If not dynamic shifting, apply fixed shift scaling like diffusers
shift := s.Config.Shift
sigma = shift * sigma / (1 + (shift-1)*sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal zero
s.Sigmas[numSteps] = 0.0
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
s.Timesteps = make([]float32, numSteps+1)
for i, v := range s.Sigmas {
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
}
}
// timeShift applies the dynamic time shift
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if s.Config.TimeShiftType == "linear" {
return mu / (mu + (1.0/t-1.0))
}
// Default: exponential
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 for precision (matches diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(outputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to bfloat16
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// CalculateShift computes the mu shift value for dynamic scheduling
// Matches diffusers compute_empirical_mu function
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
a1, b1 := float32(8.73809524e-05), float32(1.89833333)
a2, b2 := float32(0.00016927), float32(0.45666666)
seqLen := float32(imgSeqLen)
if imgSeqLen > 4300 {
return a2*seqLen + b2
}
m200 := a2*seqLen + b2
m10 := a1*seqLen + b1
a := (m200 - m10) / 190.0
b := m200 - 200.0*a
return a*float32(numSteps) + b
}

View File

@@ -0,0 +1,562 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
func (c *TransformerConfig) InnerDim() int32 {
return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
}
func (c *TransformerConfig) MLPHiddenDim() int32 {
return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
}
// TimestepEmbedder creates timestep embeddings
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"linear_1"`
Linear2 nn.LinearLayer `weight:"linear_2"`
EmbedDim int32 // 256
}
// Forward creates sinusoidal embeddings and projects them
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
half := t.EmbedDim / 2
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// timesteps: [B] -> [B, 1]
tExpanded := mlx.ExpandDims(timesteps, 1)
// args: [B, half]
args := mlx.Mul(tExpanded, freqsArr)
// [cos(args), sin(args)] -> [B, embed_dim]
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
// MLP: linear_1 -> silu -> linear_2
h := t.Linear1.Forward(sinEmbed)
h = mlx.SiLU(h)
return t.Linear2.Forward(h)
}
// TimeGuidanceEmbed wraps the timestep embedder
// Weight names: time_guidance_embed.timestep_embedder.*
type TimeGuidanceEmbed struct {
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
}
// Forward computes timestep embeddings
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
return t.TimestepEmbedder.Forward(timesteps)
}
// Modulation computes adaptive modulation parameters
// Weight names: double_stream_modulation_img.linear.weight, etc.
type Modulation struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes modulation parameters
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
h := mlx.SiLU(temb)
return m.Linear.Forward(h)
}
// TransformerBlockAttn implements dual-stream attention
// Weight names: transformer_blocks.N.attn.*
type TransformerBlockAttn struct {
// Image stream (separate Q, K, V projections)
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
// Note: to_out has .0 suffix in weights, handled specially
ToOut0 nn.LinearLayer `weight:"to_out.0"`
// Text stream (add_ projections)
AddQProj nn.LinearLayer `weight:"add_q_proj"`
AddKProj nn.LinearLayer `weight:"add_k_proj"`
AddVProj nn.LinearLayer `weight:"add_v_proj"`
ToAddOut nn.LinearLayer `weight:"to_add_out"`
// QK norms for image stream
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
// QK norms for text stream (added)
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
}
// FeedForward implements SwiGLU MLP
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
type FeedForward struct {
LinearIn nn.LinearLayer `weight:"linear_in"`
LinearOut nn.LinearLayer `weight:"linear_out"`
}
// Forward applies SwiGLU MLP
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// LinearIn outputs 2x hidden dim for SwiGLU
h := ff.LinearIn.Forward(x)
shape := h.Shape()
half := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
h = mlx.Mul(mlx.SiLU(gate), up)
return ff.LinearOut.Forward(h)
}
// TransformerBlock implements a dual-stream transformer block
// Weight names: transformer_blocks.N.*
type TransformerBlock struct {
Attn *TransformerBlockAttn `weight:"attn"`
FF *FeedForward `weight:"ff"`
FFContext *FeedForward `weight:"ff_context"`
// Config (set after loading)
NHeads int32
HeadDim int32
Scale float32
}
// Forward applies the dual-stream block
// imgHidden: [B, imgLen, dim]
// txtHidden: [B, txtLen, dim]
// imgMod, txtMod: modulation params [B, 6*dim] each
// cos, sin: RoPE values
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := imgHidden.Shape()
B := imgShape[0]
imgLen := imgShape[1]
dim := imgShape[2]
txtLen := txtHidden.Shape()[1]
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
// === Attention branch ===
// Modulate inputs
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
// Compute Q, K, V for image stream (separate projections)
imgQ := block.Attn.ToQ.Forward(imgNorm)
imgK := block.Attn.ToK.Forward(imgNorm)
imgV := block.Attn.ToV.Forward(imgNorm)
// Compute Q, K, V for text stream (add_ projections)
txtQ := block.Attn.AddQProj.Forward(txtNorm)
txtK := block.Attn.AddKProj.Forward(txtNorm)
txtV := block.Attn.AddVProj.Forward(txtNorm)
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
// Apply QK norm (RMSNorm with learned scale)
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
imgK = applyQKNorm(imgK, block.Attn.NormK)
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
// Concatenate for joint attention: text first, then image
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA: [B, nheads, L, headDim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back: [B, L, nheads, headDim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Split back into txt and img
totalLen := txtLen + imgLen
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
// Reshape and project
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
txtOut = block.Attn.ToAddOut.Forward(txtOut)
imgOut = block.Attn.ToOut0.Forward(imgOut)
// Apply gates and residual
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
// === MLP branch ===
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
imgFFOut := block.FF.Forward(imgNorm)
txtFFOut := block.FFContext.Forward(txtNorm)
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
return imgHidden, txtHidden
}
// SingleTransformerBlockAttn implements attention for single-stream blocks
// Weight names: single_transformer_blocks.N.attn.*
type SingleTransformerBlockAttn struct {
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
}
// SingleTransformerBlock implements a single-stream transformer block
// Weight names: single_transformer_blocks.N.*
type SingleTransformerBlock struct {
Attn *SingleTransformerBlockAttn `weight:"attn"`
// Config
NHeads int32
HeadDim int32
InnerDim int32
MLPHidDim int32
Scale float32
}
// Forward applies the single-stream block
// x: [B, L, dim] concatenated text+image
// mod: modulation [B, 3*dim]
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
dim := shape[2]
// Parse modulation: (shift, scale, gate)
shift, scale, gate := parseModulation3(mod, dim, 0)
// Modulate input
h := modulateLayerNorm(x, shift, scale)
// Fused projection: QKV + MLP gate/up
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
// Split: first 3*dim is QKV, rest is MLP
qkvDim := 3 * block.InnerDim
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
// Split QKV
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
// Reshape for attention
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
// QK norm
q = applyQKNorm(q, block.Attn.NormQ)
k = applyQKNorm(k, block.Attn.NormK)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back and reshape
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
// MLP: SwiGLU
mlpShape := mlpIn.Shape()
half := mlpShape[2] / 2
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
// Concatenate attention and MLP for fused output
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
// Output projection
out := block.Attn.ToOut.Forward(combined)
// Apply gate and residual
return mlx.Add(x, mlx.Mul(gate, out))
}
// NormOut implements the output normalization with modulation
// Weight names: norm_out.linear.weight
type NormOut struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes final modulated output
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
dim := shape[2]
// Modulation: temb -> silu -> linear -> [shift, scale]
mod := mlx.SiLU(temb)
mod = n.Linear.Forward(mod)
// Split into scale and shift (diffusers order: scale first, shift second)
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
// Modulate with RMSNorm
return modulateLayerNorm(x, shift, scale)
}
// Flux2Transformer2DModel is the main Flux2 transformer
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
type Flux2Transformer2DModel struct {
// Timestep embedding
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
// Shared modulation
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
// Embedders
XEmbedder nn.LinearLayer `weight:"x_embedder"`
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
// Transformer blocks
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
// Output
NormOut *NormOut `weight:"norm_out"`
ProjOut nn.LinearLayer `weight:"proj_out"`
*TransformerConfig
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
// Initialize slices
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
// Initialize TimeGuidanceEmbed with embed dim
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Flux2Transformer2DModel) initComputedFields() {
cfg := m.TransformerConfig
innerDim := cfg.InnerDim()
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
// Initialize transformer blocks
for _, block := range m.TransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.Scale = scale
}
// Initialize single transformer blocks
for _, block := range m.SingleTransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.InnerDim = innerDim
block.MLPHidDim = cfg.MLPHiddenDim()
block.Scale = scale
}
}
// Forward runs the Flux2 transformer
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
patchShape := patches.Shape()
B := patchShape[0]
imgLen := patchShape[1]
txtLen := txtEmbeds.Shape()[1]
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
// Compute timestep embedding
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
// Embed patches and text
imgHidden := m.XEmbedder.Forward(patches)
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
// Compute shared modulation
imgMod := m.DoubleStreamModulationImg.Forward(temb)
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
singleMod := m.SingleStreamModulation.Forward(temb)
// Double (dual-stream) blocks
for _, block := range m.TransformerBlocks {
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
}
// Concatenate for single-stream: text first, then image
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
// Single-stream blocks
for _, block := range m.SingleTransformerBlocks {
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
}
// Extract image portion
totalLen := txtLen + imgLen
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
// Final norm and projection
imgOut = m.NormOut.Forward(imgOut, temb)
return m.ProjOut.Forward(imgOut)
}
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
// See applyQKNorm function below
// compiledSwiGLU fuses: silu(gate) * up
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
var compiledSwiGLU *mlx.CompiledFunc
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
gate, up := inputs[0], inputs[1]
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
}, true)
}
return compiledSwiGLU
}
// Helper functions
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
B := mod.Shape()[0]
start := offset * dim
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
// Expand for broadcasting [B, dim] -> [B, 1, dim]
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
gate = mlx.ExpandDims(gate, 1)
return shift, scale, gate
}
// modulateLayerNorm applies LayerNorm then shift/scale modulation
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
// Fast LayerNorm without learnable params
x = mlx.LayerNorm(x, 1e-6)
// Modulate: x * (1 + scale) + shift
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
return mlx.Add(x, shift)
}
// splitQKV splits a fused QKV tensor into Q, K, V
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
return q, k, v
}
// applyQKNorm applies RMSNorm with learned scale (no bias)
// Uses the optimized mlx_fast_rms_norm
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
return mlx.RMSNorm(x, scale, 1e-6)
}

View File

@@ -0,0 +1,804 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
RunningMean *mlx.Array // [C]
RunningVar *mlx.Array // [C]
Weight *mlx.Array // [C] gamma
Bias *mlx.Array // [C] beta
Eps float32
Momentum float32
}
// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
// Scale and shift (only if affine=True)
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Inverse: first undo affine, then undo normalization
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
x = mlx.Sub(x, bias)
}
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
x = mlx.Add(x, mean)
return x
}
// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias,optional"`
Stride int32
Padding int32
}
// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
if field == "Weight" {
return mlx.Transpose(arr, 0, 2, 3, 1)
}
return arr
}
// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer `weight:"norm1"`
Conv1 *Conv2D `weight:"conv1"`
Norm2 *GroupNormLayer `weight:"norm2"`
Conv2 *Conv2D `weight:"conv2"`
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"`
}
// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
if rb.ConvShortcut != nil {
x = rb.ConvShortcut.Forward(x)
}
return mlx.Add(h, x)
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer `weight:"group_norm"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
}
// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
h := ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
q := ab.ToQ.Forward(h)
k := ab.ToK.Forward(h)
v := ab.ToV.Forward(h)
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
out = ab.ToOut.Forward(out)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
x = resnet.Forward(x)
}
if ub.Upsample != nil {
x = upsample2x(x)
x = ub.Upsample.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
x = mlx.Take(x, hIdx, 1)
x = mlx.Take(x, wIdx, 2)
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
x = mb.Resnet1.Forward(x)
x = mb.Attention.Forward(x)
x = mb.Resnet2.Forward(x)
return x
}
// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
return vae.DefaultTilingConfig()
}
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
type AutoencoderKLFlux2 struct {
Config *VAEConfig
// Encoder components (for image editing)
EncoderConvIn *Conv2D
EncoderMid *VAEMidBlock
EncoderDown []*DownEncoderBlock2D
EncoderNormOut *GroupNormLayer
EncoderConvOut *Conv2D
// Decoder components
DecoderConvIn *Conv2D
DecoderMid *VAEMidBlock
DecoderUp []*UpDecoderBlock2D
DecoderNormOut *GroupNormLayer
DecoderConvOut *Conv2D
// Quant conv layers
QuantConv *Conv2D
PostQuantConv *Conv2D
// BatchNorm for latent normalization
LatentBN *BatchNorm2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// DownEncoderBlock2D implements a downsampling encoder block
type DownEncoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Downsample *Conv2D
}
// Forward applies the down encoder block
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range db.ResnetBlocks {
x = resnet.Forward(x)
}
if db.Downsample != nil {
// Pad then conv with stride 2
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
x = db.Downsample.Forward(x)
}
return x
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder components (for image conditioning)
if err := m.loadEncoderWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
// Load decoder conv_in
m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
return fmt.Errorf("decoder.conv_in: %w", err)
}
// Load mid block
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("decoder.mid_block: %w", err)
}
// Load up blocks
numBlocks := len(cfg.BlockOutChannels)
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load decoder conv_norm_out and conv_out
m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
return fmt.Errorf("decoder.conv_norm_out: %w", err)
}
m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
return fmt.Errorf("decoder.conv_out: %w", err)
}
// Load post_quant_conv
if cfg.UsePostQuantConv {
m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
return fmt.Errorf("post_quant_conv: %w", err)
}
}
// Load latent BatchNorm (affine=False, so no weight/bias)
bnMean, err := weights.GetTensor("bn.running_mean")
if err != nil {
return fmt.Errorf("bn.running_mean: %w", err)
}
bnVar, err := weights.GetTensor("bn.running_var")
if err != nil {
return fmt.Errorf("bn.running_var: %w", err)
}
m.LatentBN = &BatchNorm2D{
RunningMean: bnMean,
RunningVar: bnVar,
Weight: nil, // affine=False
Bias: nil, // affine=False
Eps: cfg.BatchNormEps,
Momentum: cfg.BatchNormMomentum,
}
fmt.Println("✓")
return nil
}
// loadVAEMidBlock loads the mid block.
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// loadResnetBlock2D loads a ResNet block.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
block := &ResnetBlock2D{
Norm1: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv1: &Conv2D{Stride: 1, Padding: 1},
Norm2: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv2: &Conv2D{Stride: 1, Padding: 1},
ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, err
}
// If ConvShortcut wasn't loaded (no weights found), nil it out
if block.ConvShortcut.Weight == nil {
block.ConvShortcut = nil
}
return block, nil
}
// loadVAEAttentionBlock loads an attention block using LoadModule.
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
ab := &VAEAttentionBlock{
GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
}
if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
return nil, err
}
return ab, nil
}
// loadUpDecoderBlock2D loads an up decoder block.
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upsample = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
return nil, err
}
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
// This is the inverse of the VAE's patchify for feeding to transformer
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
pH := H / patchH
pW := W / patchW
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
H := pH * patchH
W := pW * patchW
return mlx.Reshape(x, B, C, H, W)
}
// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
x = mlx.Sub(x, bias)
}
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
x = mlx.Add(x, mean)
return x
}
// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
// Denormalize patchified latents
z := v.denormalizePatchified(latents)
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
// Convert NCHW -> NHWC for processing
z = mlx.Transpose(z, 0, 2, 3, 1)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode (no tiling)
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
h = mlx.Transpose(h, 0, 3, 1, 2)
return h
}
// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
// Post-quant conv
if vae.PostQuantConv != nil {
z = vae.PostQuantConv.Forward(z)
}
// Decoder
h := vae.DecoderConvIn.Forward(z)
h = vae.DecoderMid.Forward(h)
for _, upBlock := range vae.DecoderUp {
h = upBlock.Forward(h)
}
h = vae.DecoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.DecoderConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
return h
}
// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder conv_in
m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
return fmt.Errorf("encoder.conv_in: %w", err)
}
// Load encoder down blocks
numBlocks := len(cfg.BlockOutChannels)
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
hasDownsample := i < numBlocks-1
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load encoder mid block
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("encoder.mid_block: %w", err)
}
// Load encoder conv_norm_out and conv_out
m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
return fmt.Errorf("encoder.conv_norm_out: %w", err)
}
m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
return fmt.Errorf("encoder.conv_out: %w", err)
}
// Load quant_conv (for encoding)
if cfg.UseQuantConv {
m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
return fmt.Errorf("quant_conv: %w", err)
}
}
return nil
}
// loadDownEncoderBlock2D loads a down encoder block.
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var downsample *Conv2D
if hasDownsample {
downsample = &Conv2D{Stride: 2, Padding: 0}
if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
return nil, err
}
}
return &DownEncoderBlock2D{
ResnetBlocks: resnets,
Downsample: downsample,
}, nil
}
// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
// Convert NCHW -> NHWC
x := mlx.Transpose(image, 0, 2, 3, 1)
// Encoder
h := vae.EncoderConvIn.Forward(x)
for _, downBlock := range vae.EncoderDown {
h = downBlock.Forward(h)
}
h = vae.EncoderMid.Forward(h)
h = vae.EncoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.EncoderConvOut.Forward(h)
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
if vae.QuantConv != nil {
h = vae.QuantConv.Forward(h)
}
// Take only the mean (first latent_channels) - deterministic encoding
// h is [B, H, W, 64] -> take first 32 channels for mean
shape := h.Shape()
latentChannels := vae.Config.LatentChannels // 32
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
// Convert NHWC -> NCHW for patchifying
h = mlx.Transpose(h, 0, 3, 1, 2)
// Patchify: [B, C, H, W] -> [B, L, C*4]
h = vae.Patchify(h)
// Apply BatchNorm on patchified latents [B, L, 128]
// The BatchNorm has 128 channels matching the patchified dimension
h = vae.normalizePatchified(h)
return h
}
// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
// Scale and shift (only if affine=True)
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}

View File

@@ -0,0 +1,390 @@
//go:build mlx
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds Qwen3 text encoder configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Attention implements Qwen3 attention with QK norms
type Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking and optional padding mask
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// MLP implements Qwen3 SwiGLU MLP
type MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Block represents a single Qwen3 transformer block
type Block struct {
Attention *Attention `weight:"self_attn"`
MLP *MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h, mask, maskMode)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// TextEncoder is the full Qwen3 encoder
type TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
// This is used by Flux2 which needs embeddings from specific intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
outputs := make([]*mlx.Array, len(layerIndices))
layerSet := make(map[int]int)
for i, idx := range layerIndices {
layerSet[idx] = i
}
for i, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
if outIdx, ok := layerSet[i]; ok {
outputs[outIdx] = h
}
}
return outputs
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format.
// If think is true, adds the <think></think> block after the assistant tag
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
func ApplyChatTemplate(prompt string, think bool) string {
base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
if think {
return base + "<think>\n\n</think>\n\n"
}
return base
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
// If think is true, includes the <think></think> block in the chat template.
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
L := int32(maxLen)
validLen := int32(len(tokens))
combinedMaskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
combinedMaskData[idx] = 0
} else {
combinedMaskData[idx] = negInf
}
}
}
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
embeddings := te.Forward(tokensArr, maskMat, "")
return embeddings, maskArr
}
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
// If think is true, includes the <think></think> block in the chat template.
// Returns embeddings and padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
// Pad to maxLen
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
padded := make([]int32, maxLen)
copy(padded, tokens)
for i := len(tokens); i < maxLen; i++ {
padded[i] = padToken
}
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
// This combines causal masking with PAD token masking
L := int32(maxLen)
validLen := int32(len(tokens))
maskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
maskData[idx] = 0 // allowed: causal OK and not PAD
} else {
maskData[idx] = negInf // blocked: future or PAD
}
}
}
maskMat := mlx.NewArray(maskData, []int32{L, L})
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
// Concatenate layer outputs along the hidden dimension
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
embeddings := mlx.Concatenate(layerOutputs, 2)
// Return embeddings and padded length
return embeddings, int32(maxLen)
}

View File

@@ -17,13 +17,13 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
@@ -31,9 +31,6 @@ type GenerateConfig struct {
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
@@ -117,7 +114,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
@@ -129,7 +126,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,

View File

@@ -18,18 +18,15 @@ import (
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string

View File

@@ -3,287 +3,17 @@
package zimage
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
)
// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
Attention *Qwen3Attention `weight:"self_attn"`
MLP *Qwen3MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Qwen3Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Qwen3Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Qwen3Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = &cfg
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
cfg := m.Qwen3Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// Re-export types from shared qwen3 package for backwards compatibility
type (
Qwen3Config = qwen3.Config
Qwen3Attention = qwen3.Attention
Qwen3MLP = qwen3.MLP
Qwen3Block = qwen3.Block
Qwen3TextEncoder = qwen3.TextEncoder
)
// ApplyChatTemplate wraps prompt in Qwen3 chat format
func ApplyChatTemplate(prompt string) string {
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
embeddings := te.Forward(tokensArr)
return embeddings, maskArr
}
var ApplyChatTemplate = qwen3.ApplyChatTemplate

View File

@@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds VAE decoder configuration
@@ -636,6 +637,9 @@ type VAEDecoder struct {
UpBlocks []*UpDecoderBlock2D
ConvNormOut *GroupNormLayer
ConvOut *Conv2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// Load loads the VAE decoder from ollama blob storage.
@@ -730,45 +734,60 @@ func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfi
// Decode decodes latents to images.
// Input latents are in NCHW format, output is in NCHW format.
// Internally uses NHWC format (MLX native) for all operations.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// If Tiling is set, uses tiled decoding to reduce memory for large images.
func (v *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Scale latents
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
z := mlx.DivScalar(latents, v.Config.ScalingFactor)
z = mlx.AddScalar(z, v.Config.ShiftFactor)
// Convert NCHW -> NHWC for internal processing
z = mlx.Transpose(z, 0, 2, 3, 1)
h := vae.ConvIn.Forward(z)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
mlx.Eval(h)
return h
}
// decodeTile decodes a single latent tile to pixels.
// Input: [B, H, W, C] latent tile in NHWC format (already scaled)
// Output: [B, H*8, W*8, 3] pixel tile in NHWC format
func (v *VAEDecoder) decodeTile(z *mlx.Array) *mlx.Array {
h := v.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
h = v.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
for _, upBlock := range v.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev = h
h = vae.ConvNormOut.Forward(h)
h = v.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
h = v.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
mlx.Eval(h)
return h
}

View File

@@ -12,19 +12,20 @@ import (
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -34,9 +35,6 @@ type GenerateConfig struct {
FusedQKV bool // Enable fused QKV projection (default: false)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelName string
@@ -93,7 +91,7 @@ func (m *Model) Load(modelName string) error {
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
@@ -139,7 +137,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
@@ -151,7 +149,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
@@ -179,9 +177,16 @@ func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*m
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// generate is the internal denoising pipeline.
@@ -222,9 +227,9 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
// Text encoding with padding to multiple of 32
var posEmb, negEmb *mlx.Array
{
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
if useCFG {
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
}
// Pad both to same length (multiple of 32)
@@ -448,7 +453,11 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
teaCache.Free()
}
// VAE decode
// VAE decode - enable tiling for larger images to reduce memory
// VAE attention is O(n²) on latent pixels, tiling helps significantly
if latentH > 64 || latentW > 64 {
m.VAEDecoder.Tiling = vae.DefaultTilingConfig()
}
decoded := m.VAEDecoder.Decode(latents)
latents.Free()

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
@@ -40,10 +41,15 @@ type Response struct {
Total int `json:"total,omitempty"`
}
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
model ImageModel
modelName string
}
@@ -80,10 +86,25 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
// Detect model type and load appropriate model
modelType := imagegen.DetectModelType(*modelName)
slog.Info("detected model type", "type", modelType)
var model ImageModel
switch modelType {
case "Flux2KleinPipeline":
m := &flux2.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
default:
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
model = m
}
server := &Server{
@@ -159,26 +180,19 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image
// Generate image using the common interface
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Step: step,
Total: total,
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
enc := json.NewEncoder(w)
// Progress callback streams step updates
progress := func(step, total int) {
resp := Response{Step: step, Total: total}
enc.Encode(resp)
w.Write([]byte("\n"))
flusher.Flush()
}
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation

View File

@@ -17,6 +17,24 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
Quantization() string // Returns "FP4", "FP8", or ""
}
// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "FP4":
return 32, 4, "affine"
default:
return 32, 8, "affine" // FP8 or unknown
}
}
// Transformer allows structs to transform weight arrays before assignment.
// Implement this to apply operations like transpose during loading.
type Transformer interface {
Transform(field string, arr *mlx.Array) *mlx.Array
}
// LoadModule loads weights into a struct using reflection and struct tags.
@@ -136,6 +154,10 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
continue
}
// Transform before assigning if parent implements Transformer
if t, ok := v.Addr().Interface().(Transformer); ok {
arr = t.Transform(field.Name, arr)
}
fieldVal.Set(reflect.ValueOf(arr))
continue
}
@@ -223,19 +245,21 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
groupSize, bits, mode := quantizationParams(weights.Quantization())
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
return nn.NewLinear(dequantized, bias), nil
}

View File

@@ -298,6 +298,11 @@ func (mw *ModelWeights) HasTensor(name string) bool {
return ok
}
// Quantization returns empty string for directory-based weights (not quantized).
func (mw *ModelWeights) Quantization() string {
return ""
}
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {

View File

@@ -510,7 +510,11 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
t.vocab.Merges[merge] = i
}
// Add special tokens to vocabulary
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
@@ -518,9 +522,7 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
if tok.Special {
t.specialTokens[tok.Content] = tok.ID
}
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Load special token configuration from companion files

215
x/imagegen/vae/tiling.go Normal file
View File

@@ -0,0 +1,215 @@
//go:build mlx
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
package vae
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TilingConfig holds configuration for tiled VAE decoding.
// This is a general technique to reduce memory usage when decoding large latents.
type TilingConfig struct {
TileSize int32 // Tile size in latent space (e.g., 64 latent → 512 pixels for 8x VAE)
Overlap int32 // Overlap in latent space (e.g., 16 latent = 25% of 64)
}
// DefaultTilingConfig returns reasonable defaults matching diffusers.
// tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *TilingConfig {
return &TilingConfig{
TileSize: 64, // 64 latent pixels
Overlap: 16, // 25% overlap
}
}
// decodedTile holds a decoded tile's pixel data and dimensions
type decodedTile struct {
data []float32
height int32
width int32
}
// DecodeTiled decodes latents using tiled processing with overlap blending.
// This reduces memory usage for large images by processing in overlapping tiles.
//
// Parameters:
// - latents: [1, H, W, C] latent tensor in NHWC format
// - cfg: tiling configuration (tile size and overlap)
// - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
//
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
shape := latents.Shape()
H := shape[1] // latent height
W := shape[2] // latent width
C := shape[3]
tileLatentSize := cfg.TileSize
overlapLatent := cfg.Overlap
// If image is small enough, just decode normally
if H <= tileLatentSize && W <= tileLatentSize {
decoded := decoder(latents)
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
return decoded
}
// Calculate tiling parameters (matching diffusers)
overlapSize := tileLatentSize - overlapLatent // stride in latent space
// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
// For other scale factors, this could be made configurable
tileSampleSize := tileLatentSize * 8 // tile size in pixels after 8x upscale
blendExtent := overlapLatent * 8 // blend region in pixels
rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile
// Phase 1: Decode all tiles and store in 2D grid
var rows [][]decodedTile
for i := int32(0); i < H; i += overlapSize {
var row []decodedTile
for j := int32(0); j < W; j += overlapSize {
// Extract tile (may be smaller at edges)
i2 := min(i+tileLatentSize, H)
j2 := min(j+tileLatentSize, W)
tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
decoded := decoder(tile)
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
mlx.Eval(decoded)
decodedShape := decoded.Shape()
tileH := decodedShape[1]
tileW := decodedShape[2]
tileData := decoded.Data()
decoded.Free()
row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
}
rows = append(rows, row)
}
// Phase 2: Blend adjacent tiles (modifies in place)
for i := range rows {
for j := range rows[i] {
tile := &rows[i][j]
// Blend with tile above
if i > 0 {
above := &rows[i-1][j]
blendV(above, tile, blendExtent)
}
// Blend with tile to the left
if j > 0 {
left := &rows[i][j-1]
blendH(left, tile, blendExtent)
}
}
}
// Phase 3: Calculate crop dimensions for each tile
colWidths := make([]int32, len(rows[0]))
for j := range rows[0] {
keepW := rowLimit
if int32(j+1)*overlapSize >= W {
keepW = rows[0][j].width
}
colWidths[j] = keepW
}
rowHeights := make([]int32, len(rows))
for i := range rows {
keepH := rowLimit
if int32(i+1)*overlapSize >= H {
keepH = rows[i][0].height
}
rowHeights[i] = keepH
}
// Calculate total dimensions
var totalW, totalH int32
for _, w := range colWidths {
totalW += w
}
for _, h := range rowHeights {
totalH += h
}
// Phase 4: Assemble final image by interleaving tiles row-by-row
finalData := make([]float32, totalH*totalW*3)
dstY := int32(0)
for i, row := range rows {
keepH := rowHeights[i]
for y := int32(0); y < keepH; y++ {
dstX := int32(0)
for j, tile := range row {
keepW := colWidths[j]
for x := int32(0); x < keepW; x++ {
for c := int32(0); c < 3; c++ {
srcIdx := (y*tile.width + x) * 3 + c
dstIdx := ((dstY + y) * totalW + (dstX + x)) * 3 + c
finalData[dstIdx] = tile.data[srcIdx]
}
}
dstX += keepW
}
}
dstY += keepH
}
// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
result = mlx.Transpose(result, 0, 3, 1, 2)
result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
return result
}
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend)
// Matches diffusers blend_v formula
func blendV(above, current *decodedTile, blendExtent int32) {
blend := min(blendExtent, min(above.height, current.height))
if blend <= 0 {
return
}
w := min(above.width, current.width)
for y := int32(0); y < blend; y++ {
alpha := float32(y) / float32(blend)
for x := int32(0); x < w; x++ {
for c := int32(0); c < 3; c++ {
aboveIdx := ((above.height - blend + y) * above.width + x) * 3 + c
currIdx := (y * current.width + x) * 3 + c
current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
}
}
}
}
// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend)
// Matches diffusers blend_h formula
func blendH(left, current *decodedTile, blendExtent int32) {
blend := min(blendExtent, min(left.width, current.width))
if blend <= 0 {
return
}
h := min(left.height, current.height)
for y := int32(0); y < h; y++ {
for x := int32(0); x < blend; x++ {
alpha := float32(x) / float32(blend)
for c := int32(0); c < 3; c++ {
leftIdx := (y * left.width + (left.width - blend + x)) * 3 + c
currIdx := (y * current.width + x) * 3 + c
current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
}
}
}
}

View File

@@ -106,6 +106,21 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
return ok
}
// Quantization returns the model's quantization type from model_index.json.
// Returns empty string if not quantized or unknown.
func (mw *ManifestWeights) Quantization() string {
if mw.manifest == nil {
return ""
}
var index struct {
Quantization string `json:"quantization"`
}
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
return ""
}
return index.Quantization
}
// ReleaseAll frees all native handles and clears the tensor cache.
func (mw *ManifestWeights) ReleaseAll() {
for _, sf := range mw.nativeCache {

View File

@@ -9,7 +9,8 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// modelConfig represents the HuggingFace config.json structure
@@ -35,22 +36,22 @@ type modelConfig struct {
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
if err := mf.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
totalBytes += layer.Size
tensorCount++
}
@@ -151,27 +152,30 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
return getTensorInfoFromManifest(manifest)
return getTensorInfoFromManifest(mf)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
// Read the safetensors header from the blob
blobPath := manifest.BlobPath(layer.Digest)
blobPath, err := manifest.GetBlobsPath(layer.Digest)
if err != nil {
continue
}
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
@@ -197,15 +201,15 @@ func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor,
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// Check if model is quantized by looking for _scale tensors
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
// Model is quantized - return FP8 (affine quantization)
return "FP8", nil
@@ -217,7 +221,7 @@ func GetSafetensorsDtype(modelName string) (string, error) {
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}

View File

@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/manifest"
)
func TestBuildModelInfo(t *testing.T) {
@@ -451,8 +451,14 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
}
func TestGetTensorInfoFromManifest(t *testing.T) {
// Create a temp directory for blobs
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Create test tensor blobs
tensors := []struct {
@@ -463,26 +469,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}{
{
name: "model.embed_tokens.weight",
digest: "sha256:abc123",
digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
digest: "sha256:def456",
digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
digest: "sha256:ghi789",
digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
dtype: "F32",
shape: []int64{2560},
},
}
// Create blob files
var layers []imagegen.ManifestLayer
var layers []manifest.Layer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
@@ -498,15 +504,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
// Write blob file
blobName := "sha256-" + tensor.digest[7:]
blobPath := filepath.Join(tempDir, blobName)
// Write blob file using the digest format expected by GetBlobsPath
blobPath, err := manifest.GetBlobsPath(tensor.digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.tensor",
layers = append(layers, manifest.Layer{
MediaType: manifest.MediaTypeImageTensor,
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
@@ -514,21 +522,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
// Add a non-tensor layer (should be skipped)
layers = append(layers, imagegen.ManifestLayer{
layers = append(layers, manifest.Layer{
MediaType: "application/vnd.ollama.image.json",
Digest: "sha256:config",
Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
Size: 100,
Name: "config.json",
})
manifest := &imagegen.ModelManifest{
Manifest: &imagegen.Manifest{
Layers: layers,
},
BlobDir: tempDir,
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: layers,
}
result, err := getTensorInfoFromManifest(manifest)
result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}