mirror of https://github.com/ollama/ollama.git
synced 2026-01-20 21:40:54 -05:00

Compare commits: pdevine/ma... → fix-mlx-qu... (1 commit)

Commit 582d93ab22
@@ -749,7 +749,7 @@ type ShowResponse struct {
    Messages      []Message          `json:"messages,omitempty"`
    RemoteModel   string             `json:"remote_model,omitempty"`
    RemoteHost    string             `json:"remote_host,omitempty"`
-   ModelInfo     map[string]any     `json:"model_info"`
+   ModelInfo     map[string]any     `json:"model_info,omitempty"`
    ProjectorInfo map[string]any     `json:"projector_info,omitempty"`
    Tensors       []Tensor           `json:"tensors,omitempty"`
    Capabilities  []model.Capability `json:"capabilities,omitempty"`
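The `omitempty` change above alters how an empty ShowResponse serializes: a nil ModelInfo map is now dropped from the JSON output instead of appearing as `"model_info":null`. A minimal illustrative sketch, not part of the commit (the struct is trimmed to the one field that changed):

package main

import (
    "encoding/json"
    "fmt"
)

type showResponse struct {
    ModelInfo map[string]any `json:"model_info,omitempty"`
}

func main() {
    b, _ := json.Marshal(showResponse{}) // nil map: field omitted entirely
    fmt.Println(string(b))               // {}

    b, _ = json.Marshal(showResponse{ModelInfo: map[string]any{"general.architecture": "llama"}})
    fmt.Println(string(b)) // {"model_info":{"general.architecture":"llama"}}
}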
@@ -899,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
    for _, arg := range args {
        // Unload the model if it's running before deletion
        if err := loadOrUnloadModel(cmd, &runOptions{
-           Model:     arg,
+           Model:     args[0],
            KeepAlive: &api.Duration{Duration: 0},
        }); err != nil {
            if !strings.Contains(strings.ToLower(err.Error()), "not found") {
-               fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
+               fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
            }
        }
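For context on the changed lines: inside `for _, arg := range args`, `arg` names the model for the current iteration, while `args[0]` always names the first CLI argument, so the two variants only diverge when several models are deleted in one command. A small illustrative sketch:

package main

import "fmt"

func main() {
    args := []string{"llama3", "mistral", "gemma"} // hypothetical CLI arguments
    for _, arg := range args {
        // arg takes each value in turn; args[0] stays "llama3" on every iteration
        fmt.Println("stopping", arg, "(args[0] is", args[0]+")")
    }
}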
@@ -311,10 +311,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
        conv = &deepseekocr{}
    case "DeepseekV3ForCausalLM":
        conv = &deepseek2Model{}
-   case "Glm4MoeLiteForCausalLM":
-       conv = &glm4MoeLiteModel{}
-   case "Lfm2ForCausalLM":
-       conv = &lfm2Model{}
    default:
        return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
    }
@@ -1,150 +0,0 @@
package convert

import (
    "cmp"
    "fmt"
    "log/slog"
    "regexp"
    "strconv"

    "github.com/ollama/ollama/fs/ggml"
)

type glm4MoeLiteModel struct {
    ModelParameters
    MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
    HiddenSize            uint32  `json:"hidden_size"`
    HiddenLayers          uint32  `json:"num_hidden_layers"`
    IntermediateSize      uint32  `json:"intermediate_size"`
    NumAttentionHeads     uint32  `json:"num_attention_heads"`
    NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
    RMSNormEPS            float32 `json:"rms_norm_eps"`

    RopeTheta     float32 `json:"rope_theta"`
    QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
    QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
    KVLoraRank    uint32  `json:"kv_lora_rank"`
    QLoraRank     uint32  `json:"q_lora_rank"`
    VHeadDim      uint32  `json:"v_head_dim"`

    ExpertCount            uint32  `json:"n_routed_experts"`
    ExpertSharedCount      uint32  `json:"n_shared_experts"`
    ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
    ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
    ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
    ExpertWeightsScale     float32 `json:"routed_scaling_factor"`

    LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}

func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
    kv := p.ModelParameters.KV(t)
    kv["general.architecture"] = "glm4moelite"
    kv["general.type"] = "model"
    kv["glm4moelite.block_count"] = p.HiddenLayers

    numHeads := p.NumAttentionHeads
    numKVHeads := p.NumKeyValueHeads

    kv["glm4moelite.attention.head_count"] = numHeads
    kv["glm4moelite.attention.head_count_kv"] = numKVHeads
    kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
    kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
    kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
    kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
    kv["glm4moelite.attention.value_length"] = p.VHeadDim
    kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
    kv["glm4moelite.embedding_length"] = p.HiddenSize
    kv["glm4moelite.expert_count"] = p.ExpertCount
    kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
    kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount

    kv["glm4moelite.expert_gating_func"] = uint32(2)
    kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
    kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
    kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
    kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
    kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount

    kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
    kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))

    kv["tokenizer.ggml.pre"] = "glm4"

    return kv
}

func (p *glm4MoeLiteModel) Replacements() []string {
    return []string{
        "lm_head", "output",
        "model.embed_tokens", "token_embd",
        "model.norm", "output_norm",
        "model.layers", "blk",
        "input_layernorm", "attn_norm",
        "self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
        "self_attn.kv_a_layernorm", "attn_kv_a_norm",
        "self_attn.kv_b_proj", "attn_kv_b",
        "self_attn.q_a_proj", "attn_q_a",
        "self_attn.q_a_layernorm", "attn_q_a_norm",
        "self_attn.q_b_proj", "attn_q_b",
        "self_attn.o_proj", "attn_output",
        "post_attention_layernorm", "ffn_norm",
        "mlp.shared_experts.down_proj", "ffn_down_shexp",
        "mlp.shared_experts.gate_proj", "ffn_gate_shexp",
        "mlp.shared_experts.up_proj", "ffn_up_shexp",
        "mlp.gate_proj", "ffn_gate",
        "mlp.down_proj", "ffn_down",
        "mlp.up_proj", "ffn_up",
        "mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
        "mlp.gate", "ffn_gate_inp",
    }
}

func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
    merges := make([]merge, p.HiddenLayers*3)
    for i := range p.HiddenLayers {
        merges[i*3+0] = merge{
            fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
            fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
        }
        merges[i*3+1] = merge{
            fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
            fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
        }
        merges[i*3+2] = merge{
            fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
            fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
        }
    }

    skipLayer := func(n string, minValue uint32) bool {
        re := regexp.MustCompile(`^blk\.(\d+)`)
        matches := re.FindStringSubmatch(n)
        if matches == nil {
            return false
        }

        blkNum, err := strconv.Atoi(matches[1])
        if err != nil {
            return false
        }

        return uint32(blkNum) >= minValue
    }

    out, s = mergeTensors(s, merges...)
    for _, t := range s {
        // skip any additional layers (such as the Multi-Token Prediction layer)
        if skipLayer(t.Name(), p.HiddenLayers) {
            slog.Debug("skipping layer", "name", t.Name())
            continue
        }
        out = append(out, &ggml.Tensor{
            Name:     t.Name(),
            Kind:     t.Kind(),
            Shape:    t.Shape(),
            WriterTo: t,
        })
    }
    return out
}
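The `skipLayer` closure above drives the "skip any additional layers" check. A standalone sketch of the same logic; the layer count is hypothetical:

package main

import (
    "fmt"
    "regexp"
    "strconv"
)

// skipLayer mirrors the converter's check: tensors whose block index is at or
// beyond the model's hidden-layer count (e.g. a Multi-Token Prediction layer)
// are dropped during conversion.
func skipLayer(n string, minValue uint32) bool {
    re := regexp.MustCompile(`^blk\.(\d+)`)
    matches := re.FindStringSubmatch(n)
    if matches == nil {
        return false
    }
    blkNum, err := strconv.Atoi(matches[1])
    if err != nil {
        return false
    }
    return uint32(blkNum) >= minValue
}

func main() {
    const hiddenLayers = 28 // hypothetical layer count
    fmt.Println(skipLayer("blk.27.attn_q_a.weight", hiddenLayers)) // false: regular layer
    fmt.Println(skipLayer("blk.28.attn_q_a.weight", hiddenLayers)) // true: extra (MTP) layer
    fmt.Println(skipLayer("output_norm.weight", hiddenLayers))     // false: not a block tensor
}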
@@ -1,100 +0,0 @@
package convert

import (
    "slices"
    "strings"

    "github.com/ollama/ollama/fs/ggml"
)

type lfm2Model struct {
    ModelParameters
    HiddenSize            uint32   `json:"hidden_size"`
    NumHiddenLayers       uint32   `json:"num_hidden_layers"`
    MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
    IntermediateSize      uint32   `json:"intermediate_size"`
    NumAttentionHeads     uint32   `json:"num_attention_heads"`
    NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
    RopeTheta             float32  `json:"rope_theta"`
    NormEps               float32  `json:"norm_eps"`
    ConvLCache            uint32   `json:"conv_L_cache"`
    LayerTypes            []string `json:"layer_types"`
    TieEmbedding          bool     `json:"tie_embedding"`
}

var _ ModelConverter = (*lfm2Model)(nil)

func (p *lfm2Model) KV(t *Tokenizer) KV {
    kv := p.ModelParameters.KV(t)
    kv["general.architecture"] = "lfm2"
    kv["lfm2.vocab_size"] = p.VocabSize
    kv["lfm2.block_count"] = p.NumHiddenLayers
    kv["lfm2.embedding_length"] = p.HiddenSize
    kv["lfm2.feed_forward_length"] = p.IntermediateSize
    kv["lfm2.context_length"] = p.MaxPositionEmbeddings

    // Build per-layer KV head count array based on layer_types
    // (0 = shortconv layer, non-zero = attention layer with that many KV heads)
    kvHeadCounts := make([]uint32, p.NumHiddenLayers)
    for i := range p.NumHiddenLayers {
        if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
            kvHeadCounts[i] = p.NumKeyValueHeads
        }
    }

    kv["lfm2.attention.head_count"] = p.NumAttentionHeads
    kv["lfm2.attention.head_count_kv"] = kvHeadCounts
    kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
    kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
    kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
    kv["lfm2.rope.freq_base"] = p.RopeTheta
    kv["lfm2.shortconv.l_cache"] = p.ConvLCache

    return kv
}

func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
    var out []*ggml.Tensor

    for _, t := range ts {
        shape := t.Shape()

        // Squeeze conv weights: [D, 1, K] -> [D, K]
        if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
            if len(shape) == 3 && shape[1] == 1 {
                shape = []uint64{shape[0], shape[2]}
            }
        }

        out = append(out, &ggml.Tensor{
            Name:     t.Name(),
            Kind:     t.Kind(),
            Shape:    slices.Clone(shape),
            WriterTo: t,
        })
    }

    return out
}

func (p *lfm2Model) Replacements() []string {
    return []string{
        "model.embed_tokens", "token_embd",
        "model.embedding_norm", "output_norm",
        "model.layers", "blk",
        "operator_norm", "attn_norm",
        "self_attn.q_proj", "attn_q",
        "self_attn.k_proj", "attn_k",
        "self_attn.v_proj", "attn_v",
        "self_attn.out_proj", "attn_output",
        "self_attn.q_layernorm", "attn_q_norm",
        "self_attn.k_layernorm", "attn_k_norm",
        "conv.conv", "shortconv.conv",
        "conv.in_proj", "shortconv.in_proj",
        "conv.out_proj", "shortconv.out_proj",
        "feed_forward.w1", "ffn_gate",
        "feed_forward.w2", "ffn_down",
        "feed_forward.w3", "ffn_up",
        "ffn_norm", "ffn_norm",
    }
}
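The per-layer head-count comment above is the key detail of the LFM2 conversion: attention layers record their KV head count, shortconv layers record zero. A small sketch with a hypothetical `layer_types` value (the "conv" label is illustrative; the code only tests for "full_attention"):

package main

import "fmt"

func main() {
    // Hypothetical 6-layer LFM2 config where shortconv and attention layers alternate.
    layerTypes := []string{"conv", "conv", "full_attention", "conv", "full_attention", "conv"}
    const numKVHeads = 8

    kvHeadCounts := make([]uint32, len(layerTypes))
    for i, lt := range layerTypes {
        if lt == "full_attention" {
            kvHeadCounts[i] = numKVHeads // attention layer
        } // shortconv layers stay 0
    }
    fmt.Println(kvHeadCounts) // [0 0 8 0 8 0]
}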
@@ -40,7 +40,6 @@ const (
func (t tensorBase) Kind() uint32 {
    if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
        strings.HasSuffix(t.name, ".bias") ||
-       strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
        t.name == "token_types.weight" ||
        t.name == "v.positional_embedding_vlm" ||
        t.name == "v.tile_position_embd.weight" ||
@@ -269,8 +269,6 @@ func (kv KV) OllamaEngineRequired() bool {
        "qwen25vl",
        "qwen3", "qwen3moe",
        "qwen3vl", "qwen3vlmoe",
-       "glm4moelite",
-       "lfm2",
    }, kv.Architecture())
}
@@ -858,9 +856,7 @@ func (f GGML) FlashAttention() bool {
    return slices.Contains([]string{
        "bert",
        "gemma3",
-       "glm4moelite",
        "gptoss", "gpt-oss",
-       "lfm2",
        "mistral3",
        "olmo3",
        "qwen3", "qwen3moe",
@@ -1,148 +0,0 @@
//go:build integration

package integration

import (
    "context"
    "encoding/base64"
    "fmt"
    "strings"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
)

func TestImageGeneration(t *testing.T) {
    skipUnderMinVRAM(t, 8)

    type testCase struct {
        imageGenModel string
        visionModel   string
        prompt        string
        expectedWords []string
    }

    testCases := []testCase{
        {
            imageGenModel: "jmorgan/z-image-turbo",
            visionModel:   "llama3.2-vision",
            prompt:        "A cartoon style llama flying like a superhero through the air with clouds in the background",
            expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
        },
    }

    for _, tc := range testCases {
        t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
            ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
            defer cancel()

            client, _, cleanup := InitServerConnection(ctx, t)
            defer cleanup()

            // Pull both models
            if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
                t.Fatalf("failed to pull image gen model: %v", err)
            }
            if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
                t.Fatalf("failed to pull vision model: %v", err)
            }

            // Generate the image
            t.Logf("Generating image with prompt: %s", tc.prompt)
            imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
            if err != nil {
                if strings.Contains(err.Error(), "image generation not available") {
                    t.Skip("Target system does not support image generation")
                } else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
                    t.Skip("Windows does not support image generation yet")
                } else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
                    t.Skip("Driver is too old")
                } else if strings.Contains(err.Error(), "insufficient memory for image generation") {
                    t.Skip("insufficient memory for image generation")
                } else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
                    t.Skip("CUDA GPU is not available")
                } else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
                    // most likely linux arm - not supported yet
                    t.Skip("unsupported architecture")
                }
                t.Fatalf("failed to generate image: %v", err)
            }

            imageData, err := base64.StdEncoding.DecodeString(imageBase64)
            if err != nil {
                t.Fatalf("failed to decode image: %v", err)
            }
            t.Logf("Generated image: %d bytes", len(imageData))

            // Preload vision model and check GPU loading
            err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
            if err != nil {
                t.Fatalf("failed to load vision model: %v", err)
            }

            // Use vision model to describe the image
            chatReq := api.ChatRequest{
                Model: tc.visionModel,
                Messages: []api.Message{
                    {
                        Role:    "user",
                        Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
                        Images:  []api.ImageData{imageData},
                    },
                },
                Stream: &stream,
                Options: map[string]any{
                    "seed":        42,
                    "temperature": 0.0,
                },
            }

            // Verify the vision model's response contains expected keywords
            response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
            if response != nil {
                t.Logf("Vision model response: %s", response.Content)

                // Additional detailed check for keywords
                content := strings.ToLower(response.Content)
                foundWords := []string{}
                missingWords := []string{}
                for _, word := range tc.expectedWords {
                    if strings.Contains(content, word) {
                        foundWords = append(foundWords, word)
                    } else {
                        missingWords = append(missingWords, word)
                    }
                }
                t.Logf("Found keywords: %v", foundWords)
                if len(missingWords) > 0 {
                    t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
                }
            }
        })
    }
}

// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
    var imageBase64 string

    err := client.Generate(ctx, &api.GenerateRequest{
        Model:  model,
        Prompt: prompt,
    }, func(resp api.GenerateResponse) error {
        if resp.Image != "" {
            imageBase64 = resp.Image
        }
        return nil
    })
    if err != nil {
        return "", fmt.Errorf("failed to generate image: %w", err)
    }

    if imageBase64 == "" {
        return "", fmt.Errorf("no image data in response")
    }

    return imageBase64, nil
}
@@ -38,7 +38,6 @@ var (

    // Note: add newer models at the top of the list to test them first
    ollamaEngineChatModels = []string{
-       "lfm2.5-thinking",
        "ministral-3",
        "qwen3-coder:30b",
        "gpt-oss:20b",

@@ -144,7 +143,6 @@ var (
        "granite3.3",
        "hermes3",
        "internlm2",
-       "lfm2.5-thinking",
        "llama-guard3",
        "llama-pro",
        "llama2-chinese",

@@ -265,7 +263,6 @@ var (
        "snowflake-arctic-embed2",
    }
    libraryToolsModels = []string{
-       "lfm2.5-thinking",
        "qwen3-vl",
        "gpt-oss:20b",
        "gpt-oss:120b",
@@ -1,95 +0,0 @@
package manifest

import (
    "errors"
    "fmt"
    "os"
    "path/filepath"
    "regexp"
    "strings"

    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/types/model"
)

var ErrInvalidDigestFormat = errors.New("invalid digest format")

func GetManifestPath() (string, error) {
    path := filepath.Join(envconfig.Models(), "manifests")
    if err := os.MkdirAll(path, 0o755); err != nil {
        return "", fmt.Errorf("%w: ensure path elements are traversable", err)
    }

    return path, nil
}

// GetManifestPathForName returns the path to the manifest file for a specific model name.
func GetManifestPathForName(n model.Name) (string, error) {
    if !n.IsValid() {
        return "", os.ErrNotExist
    }

    manifests, err := GetManifestPath()
    if err != nil {
        return "", err
    }

    return filepath.Join(manifests, n.Filepath()), nil
}

func GetBlobsPath(digest string) (string, error) {
    // only accept actual sha256 digests
    pattern := "^sha256[:-][0-9a-fA-F]{64}$"
    re := regexp.MustCompile(pattern)

    if digest != "" && !re.MatchString(digest) {
        return "", ErrInvalidDigestFormat
    }

    digest = strings.ReplaceAll(digest, ":", "-")
    path := filepath.Join(envconfig.Models(), "blobs", digest)
    dirPath := filepath.Dir(path)
    if digest == "" {
        dirPath = path
    }

    if err := os.MkdirAll(dirPath, 0o755); err != nil {
        return "", fmt.Errorf("%w: ensure path elements are traversable", err)
    }

    return path, nil
}

// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
    info, err := os.Lstat(path)
    if err != nil {
        return err
    }

    if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
        entries, err := os.ReadDir(path)
        if err != nil {
            return err
        }

        for _, entry := range entries {
            if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
                return err
            }
        }

        entries, err = os.ReadDir(path)
        if err != nil {
            return err
        }

        if len(entries) > 0 {
            return nil
        }

        return os.Remove(path)
    }

    return nil
}
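A standalone sketch of the digest validation and blob filename mapping in GetBlobsPath; the digest value is fabricated for illustration:

package main

import (
    "fmt"
    "regexp"
    "strings"
)

func main() {
    // Same pattern as GetBlobsPath: "sha256" plus ":" or "-" plus 64 hex characters.
    re := regexp.MustCompile("^sha256[:-][0-9a-fA-F]{64}$")

    digest := "sha256:" + strings.Repeat("ab", 32) // hypothetical, syntactically valid digest
    fmt.Println(re.MatchString(digest))            // true
    fmt.Println(re.MatchString("sha256:deadbeef")) // false: too short
    fmt.Println(re.MatchString("md5:0123"))        // false: wrong algorithm

    // On disk the ":" is replaced with "-" to form the blob filename.
    fmt.Println(strings.ReplaceAll(digest, ":", "-"))
}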
@@ -162,7 +162,6 @@ type Tensor interface {
    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
    Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
-   SSMConv(ctx Context, kernel Tensor) Tensor

    IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
@@ -1641,13 +1641,6 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
    return tt
}

-func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
-   return &Tensor{
-       b: t.b,
-       t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
-   }
-}
-
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
    return &Tensor{
        b: t.b,
@@ -1,304 +0,0 @@
package glm4moelite

import (
    "math"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type Options struct {
    numExpertsUsed      int
    numExperts          int
    normTopKProb        bool
    routedScalingFactor float32

    kvLoraRank,
    qkNopeHeadDim,
    qkRopeHeadDim,
    kqNopeHeadDim,
    qkHeadDim int
    qLoraRank int
    vHeadDim  int

    hiddenSize,
    numHeads,
    numKVHeads int

    eps,
    ropeBase float32
    kqScale float64
}

func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
    return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}

type Attention struct {
    Q *nn.Linear `gguf:"attn_q"`

    QA     *nn.Linear  `gguf:"attn_q_a"`
    QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
    QB     *nn.Linear  `gguf:"attn_q_b"`

    KVA     *nn.Linear  `gguf:"attn_kv_a_mqa"`
    KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
    KVB     *nn.Linear  `gguf:"attn_kv_b"`

    Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}

func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    seqLength := hiddenStates.Dim(1)

    var query ml.Tensor
    if opts.qLoraRank == 0 {
        query = attn.Q.Forward(ctx, hiddenStates)
    } else {
        query = attn.QA.Forward(ctx, hiddenStates)
        query = attn.QANorm.Forward(ctx, query, opts.eps)
        query = attn.QB.Forward(ctx, query)
    }

    query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
    queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)

    compressedKV := attn.KVA.Forward(ctx, hiddenStates)
    kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
    kRot := compressedKV.View(ctx,
        opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
        compressedKV.Stride(1), 1,
        compressedKV.Stride(1), compressedKV.Dim(1),
    )

    qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
    kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
    kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
    kPass = attn.KVB.Forward(ctx, kPass)

    kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
    kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

    kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
    query = qRot.Concat(ctx, queryChunks[0], 0)
    key := kRot.Concat(ctx, kvChunks[0], 0)
    attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)

    attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
    return attn.Output.Forward(ctx, attention)
}

type MLP interface {
    Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

type sparse struct {
    Router       *nn.Linear `gguf:"ffn_gate_inp"`
    Gate         *nn.Linear `gguf:"ffn_gate_exps"`
    Up           *nn.Linear `gguf:"ffn_up_exps"`
    Down         *nn.Linear `gguf:"ffn_down_exps"`
    SharedExpert *dense     `gguf:",suf:_shexp"`
    ExpProbsBias ml.Tensor  `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}

func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

    upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
    hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
    hiddenStates = hiddenStates.SILU(ctx, upStates)

    experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
    experts = experts.Mul(ctx, topKWeights)

    nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
    for i := 1; i < opts.numExpertsUsed; i++ {
        nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
    }
    return nextStates
}

func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
    if moe.ExpProbsBias != nil {
        scores = scores.Add(ctx, moe.ExpProbsBias)
    }
    topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
    return topKIndices
}

func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
    residuals := hiddenStates

    routerLogits := moe.Router.Forward(ctx, hiddenStates)
    scores := routerLogits.Sigmoid(ctx)
    topKIndices := moe.topKIndices(ctx, scores, opts)
    topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)

    if opts.normTopKProb {
        topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
        topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
        topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
    }

    topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
    hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
    sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)

    hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
    return hiddenStates
}

type dense struct {
    Gate *nn.Linear `gguf:"ffn_gate"`
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
    hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
    return mlp.Down.Forward(ctx, hiddenStates)
}

type Layer struct {
    AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
    Attention     *Attention

    MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
    MLP     MLP
}

func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    residual := hiddenStates
    hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
    hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)

    if outputs != nil {
        hiddenStates = hiddenStates.Rows(ctx, outputs)
        residual = residual.Rows(ctx, outputs)
    }

    hiddenStates = hiddenStates.Add(ctx, residual)
    residual = hiddenStates

    hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
    hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
    hiddenStates = hiddenStates.Add(ctx, residual)
    return hiddenStates
}

type Model struct {
    model.Base
    model.BytePairEncoding

    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Layers         []Layer       `gguf:"blk"`

    OutputNorm *nn.RMSNorm `gguf:"output_norm"`
    Output     *nn.Linear  `gguf:"output,alt:token_embd"`

    *Options
}

func New(c fs.Config) (model.Model, error) {
    layers := make([]Layer, c.Uint("block_count"))

    firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
    for i := range layers {
        if i < firstDenseLayerIndex {
            layers[i].MLP = &dense{}
        } else {
            layers[i].MLP = &sparse{}
        }
    }

    keyLength := int(c.Uint("attention.key_length"))
    valueLength := int(c.Uint("attention.value_length"))

    kqScale := 1.0 / math.Sqrt(float64(keyLength))

    var pre []string
    switch c.String("tokenizer.ggml.pre") {
    case "glm4":
        pre = []string{
            `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
        }
    default:
        return nil, model.ErrUnsupportedTokenizer
    }

    m := Model{
        BytePairEncoding: model.NewBytePairEncoding(
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Types:  c.Ints("tokenizer.ggml.token_type"),
                Merges: c.Strings("tokenizer.ggml.merges"),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOS: append(
                    []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
                    c.Ints("tokenizer.ggml.eos_token_ids")...,
                ),
            },
            pre...,
        ),
        Layers: layers,
        Options: &Options{
            hiddenSize:     int(c.Uint("embedding_length")),
            numHeads:       int(c.Uint("attention.head_count")),
            numKVHeads:     int(c.Uint("attention.head_count_kv")),
            eps:            c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:       c.Float("rope.freq_base"),
            numExperts:     int(c.Uint("expert_count")),
            numExpertsUsed: int(c.Uint("expert_used_count")),
            normTopKProb:   c.Bool("expert_weights_norm", true),

            qLoraRank:     int(c.Uint("attention.q_lora_rank")),
            kvLoraRank:    int(c.Uint("attention.kv_lora_rank")),
            qkHeadDim:     keyLength,
            vHeadDim:      valueLength,
            qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
            qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
            kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),

            routedScalingFactor: c.Float("expert_weights_scale"),

            kqScale: kqScale,
        },
    }

    m.Cache = kvcache.NewCausalCache(m.Shift)
    return &m, nil
}

func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

    hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

    for i, layer := range m.Layers {
        m.Cache.SetLayer(i)

        var outputs ml.Tensor
        if i == len(m.Layers)-1 {
            outputs = batch.Outputs
        }

        hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
    }

    hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
    return m.Output.Forward(ctx, hiddenStates), nil
}

func init() {
    model.Register("glm4moelite", New)
}
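A sketch of how New derives the MLA attention dimensions from config keys: the rotary and non-rotary head dims partition the key length, and the attention scale is 1/sqrt(keyLength). The numbers are hypothetical, not taken from a real model:

package main

import (
    "fmt"
    "math"
)

func main() {
    // Hypothetical config values, mirroring how New fills in Options.
    keyLength := 576 // attention.key_length = qk_nope_head_dim + qk_rope_head_dim
    ropeDim := 64    // rope.dimension_count

    qkRopeHeadDim := ropeDim
    qkNopeHeadDim := keyLength - ropeDim
    kqScale := 1.0 / math.Sqrt(float64(keyLength))

    fmt.Println(qkNopeHeadDim, qkRopeHeadDim) // 512 64
    fmt.Printf("%.4f\n", kqScale)             // scale passed to nn.Attention
}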
@@ -1,410 +0,0 @@
package lfm2

import (
    "slices"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)

// HybridCache stores:
//   - a standard causal KV cache for attention layers
//   - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
    kv *kvcache.Causal

    backend      ml.Backend
    dtype        ml.DType
    maxSequences int

    hiddenSize int
    dConv      int

    // slot mapping for recurrent state
    slotForSeq map[int]int
    refCount   []int
    freeSlots  []int

    // per-layer conv state buffers (allocated lazily)
    convCtxs   map[int]ml.Context
    convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]

    // current forward batch (derived in StartForward)
    curSeqs       []int
    curSlots      []int
    curSlotsInput ml.Tensor
    curSeqTokens  int

    // track if EnsureWritable has been called for this forward pass
    writableEnsured bool
    // track any error from EnsureWritable to propagate later
    writableError error
}

func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
    return &HybridCache{
        kv:         kvcache.NewCausalCache(shift),
        hiddenSize: hiddenSize,
        dConv:      dConv,
        slotForSeq: make(map[int]int),
        convCtxs:   make(map[int]ml.Context),
        convStates: make(map[int]ml.Tensor),
    }
}

func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
    c.backend = backend
    c.dtype = dtype
    c.maxSequences = maxSequences

    // initialize slot allocator
    c.refCount = make([]int, maxSequences)
    c.freeSlots = c.freeSlots[:0]
    for i := maxSequences - 1; i >= 0; i-- {
        c.freeSlots = append(c.freeSlots, i)
    }

    c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

func (c *HybridCache) Close() {
    for _, ctx := range c.convCtxs {
        ctx.Close()
    }
    c.kv.Close()
}

func (c *HybridCache) SetConfig(config ml.CacheConfig) {
    c.kv.SetConfig(config)
}

func (c *HybridCache) SetLayer(layer int) {
    c.kv.SetLayer(layer)
}

func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
    return c.kv.Get(ctx)
}

func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
    c.kv.Put(ctx, key, value)
}

func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
    if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
        return err
    }

    // Derive equal-length sequence layout for shortconv.
    // LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
    seqCounts := make(map[int]int)
    c.curSeqs = c.curSeqs[:0]
    for _, s := range batch.Sequences {
        if _, ok := seqCounts[s]; !ok {
            c.curSeqs = append(c.curSeqs, s)
        }
        seqCounts[s]++
    }

    if len(c.curSeqs) == 0 {
        return nil
    }

    nTokens := len(batch.Sequences)
    nSeqs := len(c.curSeqs)
    want := nTokens / nSeqs
    for _, s := range c.curSeqs {
        if seqCounts[s] != want {
            return kvcache.ErrNotSupported
        }
    }

    c.curSeqTokens = want

    // When reserving memory for estimation, use fake slot assignments
    // without modifying permanent state (slotForSeq, refCount)
    if reserve {
        c.curSlots = c.curSlots[:0]
        slots := make([]int32, nSeqs)
        for i := range nSeqs {
            c.curSlots = append(c.curSlots, i)
            slots[i] = int32(i)
        }
        c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
        return nil
    }

    // Ensure slots exist for sequences in this batch
    c.curSlots = c.curSlots[:0]
    var newSlots []int // track newly allocated slots that need zeroing
    for _, s := range c.curSeqs {
        slot, ok := c.slotForSeq[s]
        if !ok {
            var err error
            slot, err = c.allocSlot()
            if err != nil {
                return err
            }
            c.slotForSeq[s] = slot
            c.refCount[slot] = 1
            newSlots = append(newSlots, slot)
        }
        c.curSlots = append(c.curSlots, slot)
    }

    // Zero conv state for newly allocated slots to clear stale data from previous sequences
    if len(newSlots) > 0 {
        c.zeroConvSlots(ctx, newSlots)
    }

    // Create a tensor for the current slots
    slots := make([]int32, len(c.curSlots))
    for i, v := range c.curSlots {
        slots[i] = int32(v)
    }
    c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

    // Reset writable state for new forward pass
    c.writableEnsured = false
    c.writableError = nil

    return nil
}

func (c *HybridCache) allocSlot() (int, error) {
    if len(c.freeSlots) == 0 {
        return 0, kvcache.ErrKvCacheFull
    }
    slot := c.freeSlots[len(c.freeSlots)-1]
    c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
    return slot, nil
}

func (c *HybridCache) freeSlot(slot int) {
    // Bounds check before freeing
    if slot >= 0 && slot < c.maxSequences {
        c.freeSlots = append(c.freeSlots, slot)
    }
}

// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
    if len(slots) == 0 || len(c.convStates) == 0 {
        return
    }

    // Use input context for creating tensors
    inputCtx := ctx.Input()

    // Create slot indices tensor
    slotIndices := make([]int32, len(slots))
    for i, s := range slots {
        slotIndices[i] = int32(s)
    }
    slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

    // Create zero tensor for the slots (SetRows requires F32 source)
    zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))

    // Zero each layer's conv state for these slots
    for _, buf := range c.convStates {
        ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
    }
}

// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
    for i, seq := range c.curSeqs {
        slot, ok := c.slotForSeq[seq]
        if !ok {
            continue
        }

        // Bounds check
        if slot < 0 || slot >= len(c.refCount) {
            continue
        }

        if c.refCount[slot] <= 1 {
            continue
        }

        newSlot, err := c.allocSlot()
        if err != nil {
            return err
        }
        c.refCount[slot]--
        c.refCount[newSlot] = 1
        c.slotForSeq[seq] = newSlot
        c.curSlots[i] = newSlot

        // Copy existing conv state for all initialized layers
        for _, buf := range c.convStates {
            // buf: [dConv*hiddenSize, maxSlots]
            src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
            // SetRows requires F32 source
            srcF32 := src.Cast(ctx, ml.DTypeF32)
            ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
        }
    }

    // Rebuild current slots tensor
    slots := make([]int32, len(c.curSlots))
    for i, v := range c.curSlots {
        slots[i] = int32(v)
    }
    c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

    return nil
}

func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
    // KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
    c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

    // For shortconv state we implement copy-on-write: dst shares the same slot as src.
    // On the first write to dst, EnsureWritable will create a private slot.
    if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
        // Bounds check before decrementing
        if dstSlot >= 0 && dstSlot < len(c.refCount) {
            c.refCount[dstSlot]--
            if c.refCount[dstSlot] <= 0 {
                c.refCount[dstSlot] = 0
                c.freeSlot(dstSlot)
            }
        }
        delete(c.slotForSeq, dstSeq)
    }

    srcSlot, ok := c.slotForSeq[srcSeq]
    if !ok {
        // src may not have a slot yet; dst will allocate on demand
        return
    }

    // Bounds check before incrementing
    if srcSlot >= 0 && srcSlot < len(c.refCount) {
        c.slotForSeq[dstSeq] = srcSlot
        c.refCount[srcSlot]++
    }
}

func (c *HybridCache) CanResume(seq int, pos int32) bool {
    return c.kv.CanResume(seq, pos)
}

func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
    if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
        return err
    }

    // For recurrent state, any removal invalidates the state because
    // the state at position N depends on all previous positions.
    // Drop the slot mapping so it resets on next use.
    slot, ok := c.slotForSeq[seq]
    if !ok {
        return nil
    }

    // Bounds check
    if slot < 0 || slot >= len(c.refCount) {
        delete(c.slotForSeq, seq)
        return nil
    }

    c.refCount[slot]--
    if c.refCount[slot] <= 0 {
        c.refCount[slot] = 0
        c.freeSlot(slot)
    }
    delete(c.slotForSeq, seq)

    return nil
}

func (c *HybridCache) slotsTensor() ml.Tensor {
    return c.curSlotsInput
}

func (c *HybridCache) seqTokens() int {
    return c.curSeqTokens
}

func (c *HybridCache) numSeqs() int {
    return len(c.curSeqs)
}

func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
    if buf, ok := c.convStates[layer]; ok {
        return buf
    }

    if _, ok := c.convCtxs[layer]; !ok {
        c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
    }

    buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
    c.convStates[layer] = buf
    return buf
}

// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
    if !c.writableEnsured {
        needsWritable := false
        for _, seq := range c.curSeqs {
            slot, ok := c.slotForSeq[seq]
            if !ok {
                continue
            }
            if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
                needsWritable = true
                break
            }
        }

        if needsWritable {
            if err := c.EnsureWritable(ctx); err != nil {
                c.writableError = err
            }
        }
        c.writableEnsured = true
    }

    if c.writableError != nil {
        return nil, c.writableError
    }

    buf := c.convBuffer(ctx, layer)
    cur := buf.Rows(ctx, c.slotsTensor())
    return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
    buf := c.convBuffer(ctx, layer)
    src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
    // SetRows requires F32 source
    srcF32 := src.Cast(ctx, ml.DTypeF32)
    ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}

// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
    return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
    return slices.Clone(c.curSeqs)
}
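A sketch of the conv-state buffer sizing described in the HybridCache comment: each sequence's [dConv, hiddenSize] state flattens into one row of a [dConv*hiddenSize, maxSlots] buffer, and ConvState gathers one such row per batch sequence before reshaping back to [dConv, hiddenSize, nSeqs]. The sizes below are hypothetical:

package main

import "fmt"

func main() {
    // Hypothetical shortconv sizes: conv_L_cache = 4, so dConv = L_cache - 1 = 3.
    const (
        lCache     = 4
        dConv      = lCache - 1
        hiddenSize = 2048
        maxSlots   = 8 // one recurrent-state slot per concurrent sequence
    )

    rowLen := dConv * hiddenSize
    fmt.Println("elements per slot row:", rowLen)        // 6144
    fmt.Println("buffer shape:", rowLen, "x", maxSlots)  // per-layer backing tensor
}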
@@ -1,444 +0,0 @@
|
||||
package lfm2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// TestHybridCache tests verify the slot management logic of HybridCache.
|
||||
// These tests focus on the recurrent state slot allocation, reference counting,
|
||||
// and copy-on-write semantics without requiring a full ML backend.
|
||||
|
||||
// createSlotOnlyCache creates a HybridCache with only the slot management
|
||||
// fields initialized. Used to test slot logic in isolation.
|
||||
func createSlotOnlyCache(maxSequences int) *HybridCache {
|
||||
return &HybridCache{
|
||||
hiddenSize: 256,
|
||||
dConv: 3,
|
||||
maxSequences: maxSequences,
|
||||
refCount: make([]int, maxSequences),
|
||||
freeSlots: initFreeSlots(maxSequences),
|
||||
slotForSeq: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func initFreeSlots(n int) []int {
|
||||
slots := make([]int, 0, n)
|
||||
for i := n - 1; i >= 0; i-- {
|
||||
slots = append(slots, i)
|
||||
}
|
||||
return slots
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotAllocation(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Verify initial state
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Allocate all slots
|
||||
for range 4 {
|
||||
slot, err := cache.allocSlot()
|
||||
if err != nil {
|
||||
t.Fatalf("allocSlot failed: %v", err)
|
||||
}
|
||||
cache.refCount[slot] = 1
|
||||
}
|
||||
|
||||
// Should be full now
|
||||
if len(cache.freeSlots) != 0 {
|
||||
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Trying to allocate another should fail
|
||||
_, err := cache.allocSlot()
|
||||
if err != kvcache.ErrKvCacheFull {
|
||||
t.Errorf("expected ErrKvCacheFull, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotReuse(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate a slot
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free it
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
|
||||
// Allocate again - should get the same slot back (LIFO)
|
||||
slot2, _ := cache.allocSlot()
|
||||
if slot2 != slot1 {
|
||||
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Simulate sharing slot with seq 2 (copy-on-write style)
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Should share the same slot
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
}
|
||||
|
||||
// Ref count should be 2
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Share with seq 2
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Unshare seq 2
|
||||
cache.refCount[slot1]--
|
||||
delete(cache.slotForSeq, 2)
|
||||
|
||||
// Ref count should be back to 1
|
||||
if cache.refCount[slot1] != 1 {
|
||||
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Seq 2 should no longer have a slot
|
||||
if _, ok := cache.slotForSeq[2]; ok {
|
||||
t.Error("seq 2 should not have a slot after unshare")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Free the slot when refCount drops to 0
|
||||
cache.refCount[slot1]--
|
||||
if cache.refCount[slot1] <= 0 {
|
||||
cache.refCount[slot1] = 0
|
||||
cache.freeSlot(slot1)
|
||||
}
|
||||
delete(cache.slotForSeq, 1)
|
||||
|
||||
// Slot should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
|
||||
}
|
||||
|
||||
// Ref count should be 0
|
||||
if cache.refCount[slot1] != 0 {
|
||||
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_SlotOverwrite(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Allocate slots for seq 1 and seq 2
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
slot2, _ := cache.allocSlot()
|
||||
cache.slotForSeq[2] = slot2
|
||||
cache.refCount[slot2] = 1
|
||||
|
||||
initialFreeSlots := len(cache.freeSlots)
|
||||
|
||||
// Simulate overwriting seq 2's slot with slot1 (sharing)
|
||||
// First free the old slot
|
||||
cache.refCount[slot2]--
|
||||
if cache.refCount[slot2] <= 0 {
|
||||
cache.refCount[slot2] = 0
|
||||
cache.freeSlot(slot2)
|
||||
}
|
||||
// Then share slot1
|
||||
cache.slotForSeq[2] = slot1
|
||||
cache.refCount[slot1]++
|
||||
|
||||
// Seq 2 should now share slot1
|
||||
if cache.slotForSeq[2] != slot1 {
|
||||
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
|
||||
}
|
||||
|
||||
// Old slot2 should be freed
|
||||
if len(cache.freeSlots) != initialFreeSlots+1 {
|
||||
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_BoundsChecking(t *testing.T) {
|
||||
cache := createSlotOnlyCache(4)
|
||||
|
||||
// Test freeing invalid slot (should not panic)
|
||||
cache.freeSlot(-1)
|
||||
cache.freeSlot(100) // out of bounds
|
||||
|
||||
// freeSlot does bounds checking, so invalid slots should be ignored
|
||||
if len(cache.freeSlots) != 4 {
|
||||
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
|
||||
cache := createSlotOnlyCache(8)
|
||||
|
||||
// Allocate slot for seq 1
|
||||
slot1, _ := cache.allocSlot()
|
||||
cache.slotForSeq[1] = slot1
|
||||
cache.refCount[slot1] = 1
|
||||
|
||||
// Fork to seq 2, 3, 4 (all share slot1)
|
||||
for _, seq := range []int{2, 3, 4} {
|
||||
cache.slotForSeq[seq] = slot1
|
||||
cache.refCount[slot1]++
|
||||
}
|
||||
|
||||
// Ref count should be 4
|
||||
if cache.refCount[slot1] != 4 {
|
||||
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Remove seq 2, 3
|
||||
for _, seq := range []int{2, 3} {
|
||||
delete(cache.slotForSeq, seq)
|
||||
cache.refCount[slot1]--
|
||||
}
|
||||
|
||||
if cache.refCount[slot1] != 2 {
|
||||
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
|
||||
}
|
||||
|
||||
// Slot should still be allocated (not in free list)
|
||||
found := false
|
||||
for _, s := range cache.freeSlots {
|
||||
if s == slot1 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
t.Error("slot1 should not be in free list yet")
|
||||
}
|
||||
|
||||
	// Remove remaining sequences
	for _, seq := range []int{1, 4} {
		delete(cache.slotForSeq, seq)
		cache.refCount[slot1]--
	}

	if cache.refCount[slot1] != 0 {
		t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_ChainedSharing(t *testing.T) {
	cache := createSlotOnlyCache(8)

	// Create seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Share 1 -> 2
	cache.slotForSeq[2] = slot1
	cache.refCount[slot1]++

	// Share 2 -> 3 (should still share slot1)
	cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
	cache.refCount[slot1]++

	// All should share slot1
	if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
		t.Error("all sequences should share slot1")
	}

	if cache.refCount[slot1] != 3 {
		t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
	}
}

func TestHybridCache_CacheParameters(t *testing.T) {
	cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5

	if cache.hiddenSize != 512 {
		t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
	}
	if cache.dConv != 5 {
		t.Errorf("expected dConv 5, got %d", cache.dConv)
	}
}

func TestHybridCache_NumSeqs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially no sequences
	if cache.numSeqs() != 0 {
		t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
	}

	// Manually set up current batch state
	cache.curSeqs = []int{1, 2, 3}

	if cache.numSeqs() != 3 {
		t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
	}
}

func TestHybridCache_SeqTokens(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially 0
	if cache.seqTokens() != 0 {
		t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
	}

	// Manually set up current batch state
	cache.curSeqTokens = 16

	if cache.seqTokens() != 16 {
		t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
	}
}

// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
	cache := createSlotOnlyCache(4)

	cache.curSeqs = []int{1, 2, 3}

	seqs := cache.Seqs()

	// Modify returned slice
	seqs[0] = 999

	// Original should be unchanged
	if cache.curSeqs[0] != 1 {
		t.Error("Seqs should return a clone, not the original slice")
	}
}

func TestHybridCache_IsSupportedForBatch(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Initially not supported (no batch set up)
	if cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be false initially")
	}

	// Set up a valid batch
	cache.curSeqTokens = 1
	cache.curSeqs = []int{1}

	if !cache.IsSupportedForBatch() {
		t.Error("expected IsSupportedForBatch to be true with valid batch")
	}
}

func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// zeroConvSlots should handle empty slots without panicking
	cache.zeroConvSlots(nil, nil)
	cache.zeroConvSlots(nil, []int{})

	// zeroConvSlots should handle empty convStates without panicking
	cache.zeroConvSlots(nil, []int{0, 1, 2})
}

func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Allocate slot for seq 1
	slot1, _ := cache.allocSlot()
	cache.slotForSeq[1] = slot1
	cache.refCount[slot1] = 1

	// Free the slot (simulating sequence removal)
	cache.refCount[slot1]--
	cache.freeSlot(slot1)
	delete(cache.slotForSeq, 1)

	// Verify slot is in free list
	if len(cache.freeSlots) != 4 {
		t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
	}

	// Allocate for new seq 2 - should get recycled slot
	slot2, _ := cache.allocSlot()
	if slot2 != slot1 {
		t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
	}

	// This recycled slot would need zeroing in the real implementation
	// The actual zeroing is tested via integration tests since it requires ML context
}

func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
	cache := createSlotOnlyCache(4)

	// Simulate the slot allocation flow from StartForward
	// When a sequence doesn't have a slot, it gets allocated and tracked as "new"

	newSlots := []int{}

	// Seq 1 doesn't have a slot - allocate and track
	seq := 1
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, err := cache.allocSlot()
		if err != nil {
			t.Fatalf("allocSlot failed: %v", err)
		}
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots = append(newSlots, slot)
	}

	// Verify newSlots contains the allocated slot
	if len(newSlots) != 1 {
		t.Errorf("expected 1 new slot, got %d", len(newSlots))
	}

	// Seq 1 already has a slot - should NOT be tracked as new
	newSlots2 := []int{}
	if _, ok := cache.slotForSeq[seq]; !ok {
		slot, _ := cache.allocSlot()
		cache.slotForSeq[seq] = slot
		cache.refCount[slot] = 1
		newSlots2 = append(newSlots2, slot)
	}

	// Verify no new slots for existing sequence
	if len(newSlots2) != 0 {
		t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
	}
}
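// Editor's sketch: a minimal, standalone version of the free-list semantics
// the tests above assume. The real HybridCache lives in this package; the
// names here (slotCache, newSlotCache) are hypothetical and only illustrate
// why TestHybridCache_SlotRecycling_TracksNewSlots expects slot2 == slot1:
// the free list is LIFO, so a just-freed slot is handed out first.
package lfm2sketch

import "errors"

type slotCache struct {
	freeSlots  []int       // LIFO free list
	slotForSeq map[int]int // sequence ID -> slot
	refCount   []int       // reference counts for slots shared across sequences
}

func newSlotCache(n int) *slotCache {
	c := &slotCache{slotForSeq: map[int]int{}, refCount: make([]int, n)}
	for i := 0; i < n; i++ {
		c.freeSlots = append(c.freeSlots, i)
	}
	return c
}

func (c *slotCache) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, errors.New("no free slots")
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

func (c *slotCache) freeSlot(slot int) {
	c.freeSlots = append(c.freeSlots, slot)
}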
@@ -1,253 +0,0 @@
package lfm2

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Options struct {
	hiddenSize int
	headDim, ropeDim int

	eps, ropeBase, ropeScale float32

	ropeType string
	originalContextLength int

	// per-layer head counts (LFM2 alternates attention and recurrent layers)
	numHeadsByLayer []int
	numKVHeadsByLayer []int
}

func (o Options) headDimValue() int {
	// Head dim is shared across layers; fall back to first attention layer head count.
	for _, h := range o.numHeadsByLayer {
		if h > 0 {
			return cmp.Or(o.headDim, o.hiddenSize/h)
		}
	}
	return cmp.Or(o.headDim, o.hiddenSize)
}

func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	opts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
		opts = append(opts,
			rope.WithOriginalContextLength(o.originalContextLength),
			rope.WithExtrapolationFactor(1.),
			rope.WithAttentionFactor(attnFactor),
		)
	}

	headCount := 1
	for _, h := range o.numHeadsByLayer {
		if h > 0 {
			headCount = h
			break
		}
	}
	return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}

type Model struct {
	model.Base
	model.TextProcessor

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers []Layer `gguf:"blk"`
	OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
	Output *nn.Linear `gguf:"output,alt:token_embd"`

	Options
}

func New(c fs.Config) (model.Model, error) {
	if c.Uint("expert_count") > 0 {
		return nil, model.ErrUnsupportedModel
	}

	if c.String("tokenizer.ggml.model") != "gpt2" {
		return nil, model.ErrUnsupportedTokenizer
	}

	vocabulary := model.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Scores: c.Floats("tokenizer.ggml.scores"),
		Types: c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
		BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
		EOS: append(
			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
			c.Ints("tokenizer.ggml.eos_token_ids")...,
		),
	}

	var pretokenizers []string
	switch c.String("tokenizer.ggml.pre") {
	case "default":
		// use default BPE pretokenizer
	default:
		// llama-bpe style (default for LFM2)
		pretokenizers = []string{
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		}
	}

	m := Model{
		TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: Options{
			hiddenSize: int(c.Uint("embedding_length")),
			headDim: int(c.Uint("attention.key_length")),
			ropeDim: int(c.Uint("rope.dimension_count")),
			eps: c.Float("attention.layer_norm_rms_epsilon"),
			ropeType: c.String("rope.scaling.type"),
			ropeBase: c.Float("rope.freq_base"),
			ropeScale: c.Float("rope.scaling.factor", 1),
			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
		},
	}

	type headCounts interface {
		HeadCount() []uint64
		HeadCountKV() []uint64
	}
	hc, ok := c.(headCounts)
	if !ok {
		return nil, model.ErrUnsupportedModel
	}

	headCount := hc.HeadCount()
	headCountKV := hc.HeadCountKV()

	m.numHeadsByLayer = make([]int, len(m.Layers))
	m.numKVHeadsByLayer = make([]int, len(m.Layers))
	for i := range m.Layers {
		m.numHeadsByLayer[i] = int(headCount[i])
		m.numKVHeadsByLayer[i] = int(headCountKV[i])

		if m.numKVHeadsByLayer[i] == 0 {
			m.Layers[i].Operator = &ShortConv{}
		} else {
			m.Layers[i].Operator = &Attention{}
		}
	}

	lCache := int(c.Uint("shortconv.l_cache"))
	dConv := max(0, lCache-1)
	m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
	return &m, nil
}

type Operator interface {
	Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}

type Attention struct {
	Query *nn.Linear `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key *nn.Linear `gguf:"attn_k"`
	KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
	Value *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}

func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
	batchSize := hiddenStates.Dim(1)
	headDim := opts.headDimValue()
	numHeads := opts.numHeadsByLayer[layer]
	numKVHeads := opts.numKVHeadsByLayer[layer]

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, headDim, numHeads, batchSize)
	key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, numKVHeads, batchSize)

	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
	key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
	return sa.Output.Forward(ctx, attention)
}

type MLP struct {
	Up *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Operator Operator
	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP *MLP
}

func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)

	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
}

func init() {
	model.Register("lfm2", New)
}
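// Editor's sketch: the YaRN attention factor computed inline in
// applyRotaryPositionEmbeddings above is 1 / (1 + 0.1*ln(ropeScale)).
// A standalone check with an assumed scale (not part of the model code):
package main

import (
	"fmt"
	"math"
)

func main() {
	ropeScale := 4.0 // hypothetical rope.scaling.factor from the GGUF config
	attnFactor := 1.0 / (1.0 + 0.1*math.Log(ropeScale))
	fmt.Printf("attention factor for scale %v: %.4f\n", ropeScale, attnFactor) // ~0.8782
}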
@@ -1,50 +0,0 @@
package lfm2

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type shortConvKernel struct {
	Weight ml.Tensor `gguf:"weight"`
}

// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
	Conv *shortConvKernel `gguf:"shortconv.conv"`
	InProj *nn.Linear `gguf:"shortconv.in_proj"`
	OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}

func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
	nSeqs := cache.numSeqs()
	seqTokens := cache.seqTokens()
	hiddenSize := hiddenStates.Dim(0)
	if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
		panic("lfm2: unsupported batch layout for shortconv")
	}

	bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)

	elementSize := bcx.Stride(0)
	b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
	c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
	x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)

	bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)

	state, err := cache.ConvState(ctx, layer)
	if err != nil {
		panic("lfm2: failed to get conv state: " + err.Error())
	}
	sx := state.Concat(ctx, bx, 0)

	convOut := sx.SSMConv(ctx, sc.Conv.Weight)
	y := c.Mul(ctx, convOut)

	dConv := sx.Dim(0) - seqTokens
	cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))

	return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}
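// Editor's sketch: the conv-state bookkeeping from ShortConv.Forward above,
// replayed on plain slices for one channel. The cache keeps the last dConv
// inputs; each step concatenates state and new tokens, convolves, then keeps
// the trailing dConv entries as the next state. Tensor shapes, the other
// sequences, and SSMConv itself are elided.
package main

import "fmt"

func main() {
	dConv := 2
	state := []float32{0, 0}     // zeroed state for a fresh slot
	tokens := []float32{1, 2, 3} // this batch's inputs for one channel
	sx := append(append([]float32{}, state...), tokens...)
	// ... the depthwise convolution over sx would run here ...
	next := sx[len(sx)-dConv:] // keep the trailing dConv entries, as UpdateConvState does
	fmt.Println(next)          // [2 3]
}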
@@ -7,9 +7,7 @@ import (
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
	_ "github.com/ollama/ollama/model/models/glm4moelite"
	_ "github.com/ollama/ollama/model/models/gptoss"
	_ "github.com/ollama/ollama/model/models/lfm2"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/llama4"
	_ "github.com/ollama/ollama/model/models/mistral3"
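// Editor's note: these blank imports exist only for their side effects; each
// architecture package self-registers in init(), as the lfm2 file above does:
//
//	func init() {
//		model.Register("lfm2", New)
//	}
//
// so removing a line from this hunk removes that architecture from the
// runtime registry (here, glm4moelite and lfm2 are being dropped).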
@@ -1,410 +0,0 @@
package parsers

import (
	"context"
	"encoding/xml"
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type glm46ParserState int

const (
	glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
	glm46ParserState_ThinkingStartedEatingWhitespace
	glm46ParserState_CollectingThinking
	glm46ParserState_ThinkingDoneEatingWhitespace
	glm46ParserState_CollectingContent
	glm46ParserState_ToolStartedEatingWhitespace
	glm46ParserState_CollectingToolContent
)

const (
	glm46ThinkingOpenTag = "<think>"
	glm46ThinkingCloseTag = "</think>"
	glm46ToolOpenTag = "<tool_call>"
	glm46ToolCloseTag = "</tool_call>"
)

type GLM46Parser struct {
	state glm46ParserState
	buffer strings.Builder
	tools []api.Tool
}

func (p *GLM46Parser) HasToolSupport() bool {
	return true
}

func (p *GLM46Parser) HasThinkingSupport() bool {
	return true
}

func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	return tools
}

type glm46Event interface {
	isGLM46Event()
}

type glm46EventContent struct {
	content string
}

func (glm46EventContent) isGLM46Event() {}

type glm46EventRawToolCall struct {
	raw string
}

func (glm46EventRawToolCall) isGLM46Event() {}

type glm46EventThinkingContent struct {
	content string
}

func (glm46EventThinkingContent) isGLM46Event() {}

func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder

	for _, event := range events {
		switch event := event.(type) {
		case glm46EventRawToolCall:
			toolCall, err := parseGLM46ToolCall(event, p.tools)
			if err != nil {
				slog.Warn("glm-4.6 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			toolCalls = append(toolCalls, toolCall)
		case glm46EventThinkingContent:
			thinkingSb.WriteString(event.content)
		case glm46EventContent:
			// TODO(drifkin): if the same turn contains multiple interleaved content
			// events, we naively append them together here.
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *GLM46Parser) parseEvents() []glm46Event {
	var all []glm46Event

	keepLooping := true
	for keepLooping {
		var events []glm46Event
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	if len(all) > 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
	}

	return all
}

// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
	p.buffer.Reset()
	if trimmed == "" {
		return nil, false // Still only whitespace, keep waiting for more input
	}
	p.state = nextState
	p.buffer.WriteString(trimmed)
	return nil, true // Successfully transitioned
}

// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
	split := strings.SplitN(p.buffer.String(), tag, 2)
	before := split[0]
	before = strings.TrimRightFunc(before, unicode.IsSpace)
	after := split[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}
	p.buffer.Reset()
	p.buffer.WriteString(after)
	return before, after
}

func (p *GLM46Parser) eat() ([]glm46Event, bool) {
	var events []glm46Event

	switch p.state {
	case glm46ParserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = glm46ParserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = glm46ParserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case glm46ParserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)

	case glm46ParserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ThinkingCloseTag) {
			thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, glm46EventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = glm46ParserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)

	case glm46ParserState_CollectingContent:
		if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
			before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
			if len(before) > 0 {
				events = append(events, glm46EventContent{content: before})
			}
			if after == "" {
				p.state = glm46ParserState_ToolStartedEatingWhitespace
			} else {
				p.state = glm46ParserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		} else {
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, glm46EventContent{content: unambiguous})
			}
			return events, false
		}

	case glm46ParserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)

	case glm46ParserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, glm46ToolCloseTag) {
			toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm46 tool call closing tag found but no content before it")
			}
			events = append(events, glm46EventRawToolCall{raw: toolContent})
			p.state = glm46ParserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed
			// We just wait for the closing tag
			return events, false
		}

	default:
		panic("unreachable")
	}
}

// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string `xml:",chardata"` // Function name (text nodes between tags)
	Keys []string `xml:"arg_key"` // All arg_key elements in document order
	Values []string `xml:"arg_value"` // All arg_value elements in document order
}

// escapeGLM46Content escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
	var result strings.Builder
	inTag := false

	for i := range len(s) {
		ch := s[i]

		if ch == '<' {
			// Check if this is a known tag
			if strings.HasPrefix(s[i:], "<arg_key>") ||
				strings.HasPrefix(s[i:], "</arg_key>") ||
				strings.HasPrefix(s[i:], "<arg_value>") ||
				strings.HasPrefix(s[i:], "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
		} else {
			// Escape special characters in text content
			switch ch {
			case '&':
				result.WriteString("&amp;")
			case '<':
				result.WriteString("&lt;")
			case '>':
				result.WriteString("&gt;")
			default:
				result.WriteByte(ch)
			}
		}
	}

	return result.String()
}

func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	// We need to escape text between tags, but not the tags themselves
	escaped := escapeGLM46Content(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed GLMToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name: functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}
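// Editor's sketch: eat() above leans on two helpers defined elsewhere in this
// package. Their assumed behavior, in standalone form: overlap reports the
// length of the longest suffix of s that is a prefix of tag (a partially
// received tag), and trailingWhitespaceLen counts trailing whitespace bytes
// that must be withheld in case they precede a tag.
package main

import (
	"fmt"
	"strings"
	"unicode"
)

func overlap(s, tag string) int {
	limit := len(tag) - 1 // a full tag is caught by strings.Contains first
	if len(s) < limit {
		limit = len(s)
	}
	for n := limit; n > 0; n-- {
		if strings.HasSuffix(s, tag[:n]) {
			return n
		}
	}
	return 0
}

func trailingWhitespaceLen(s string) int {
	return len(s) - len(strings.TrimRightFunc(s, unicode.IsSpace))
}

func main() {
	fmt.Println(overlap("thinking content</th", "</think>")) // 4
	fmt.Println(trailingWhitespaceLen("thinking \n"))         // 2
}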
@@ -1,862 +0,0 @@
package parsers

import (
	"encoding/xml"
	"reflect"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestGLM46ParserStreaming(t *testing.T) {
	type step struct {
		input string
		wantEvents []glm46Event
	}

	cases := []struct {
		desc string
		steps []step
		only bool
	}{
		{
			desc: "leading whitespace before think tag",
			steps: []step{
				{
					input: " \n\t ",
					wantEvents: []glm46Event{},
				},
				{
					input: "<think>thinking</think>",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
				},
			},
		},
		{
			desc: "think tag with whitespace inside",
			steps: []step{
				{
					input: "<think> \n thinking content \n </think>regular content",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking content"},
						glm46EventContent{content: "regular content"},
					},
				},
			},
		},
		{
			desc: "tool call with leading whitespace after opening tag",
			steps: []step{
				{
					input: "<think></think><tool_call> \n test \n </tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "test"},
					},
				},
			},
		},
		{
			desc: "simple thinking then content",
			steps: []step{
				{
					input: "<think>I am thinking</think>Now I respond",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "I am thinking"},
						glm46EventContent{content: "Now I respond"},
					},
				},
			},
		},
		{
			desc: "streamed thinking content",
			steps: []step{
				{
					input: "<think>hello",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
				},
				{
					input: " world",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
				},
				{
					input: "</think>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "content before tool call",
			steps: []step{
				{
					input: "<think>Let me call a tool</think>here is text<tool_call>",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "Let me call a tool"},
						glm46EventContent{content: "here is text"},
					},
				},
				{
					input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
					},
				},
			},
		},
		{
			desc: "tool call with content after",
			steps: []step{
				{
					input: "<think>thinking</think><tool_call>test</tool_call>after tool",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "after tool"},
					},
				},
			},
		},
		{
			desc: "trailing whitespace between content and tool call is trimmed",
			steps: []step{
				{
					input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventContent{content: "content"},
						glm46EventRawToolCall{raw: "test"},
					},
				},
			},
		},
		{
			desc: "trailing whitespace between tool call and content is trimmed",
			steps: []step{
				{
					input: "<think>think</think><tool_call>test</tool_call>\n\t after",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "think"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "split thinking close tag",
			steps: []step{
				{
					input: "<think>thinking content</th",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
				},
				{
					input: "ink>after",
					wantEvents: []glm46Event{
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "split thinking open tag",
			steps: []step{
				{
					input: " <thi",
					wantEvents: []glm46Event{},
				},
				{
					input: "nk>content</think>",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
				},
			},
		},
		{
			desc: "split tool open tag",
			steps: []step{
				{
					input: "<think>think</think>content<tool",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
				},
				{
					input: "_call>inside",
					wantEvents: []glm46Event{},
				},
				{
					input: "</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "inside"},
					},
				},
			},
		},
		{
			desc: "partial thinking close tag fakeout",
			steps: []step{
				{
					input: "<think>content</th",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
				},
				{
					input: "ought more",
					wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
				},
			},
		},
		{
			desc: "partial thinking open tag fakeout",
			steps: []step{
				{
					input: " <thi",
					wantEvents: []glm46Event{},
				},
				{
					input: "nking is fun",
					wantEvents: []glm46Event{
						glm46EventContent{content: " <thinking is fun"},
					},
				},
			},
		},
		{
			desc: "partial tool open tag fakeout",
			steps: []step{
				{
					input: "<think></think>content\n<tool",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
				{
					input: " fakeout",
					wantEvents: []glm46Event{
						glm46EventContent{content: "\n<tool fakeout"},
					},
				},
			},
		},
		{
			desc: "partial tool close tag fakeout",
			steps: []step{
				{
					input: "<think></think><tool_call>content</tool",
					wantEvents: []glm46Event{},
				},
				{
					input: " fakeout",
					wantEvents: []glm46Event{},
				},
				{
					input: "</tool_call>",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "content</tool fakeout"},
					},
				},
			},
		},
		{
			desc: "empty thinking tag",
			steps: []step{
				{
					input: "<think></think>content here",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content here"},
					},
				},
			},
		},
		{
			desc: "multiple tool calls in sequence",
			steps: []step{
				{
					input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "think"},
						glm46EventRawToolCall{raw: "first"},
						glm46EventContent{content: "between"},
						glm46EventRawToolCall{raw: "second"},
						glm46EventContent{content: "end"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - direct to content",
			steps: []step{
				{
					input: "just content here",
					wantEvents: []glm46Event{
						glm46EventContent{content: "just content here"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - skip to content then tool call",
			steps: []step{
				{
					input: "Here's the answer:<tool_call>test</tool_call>done",
					wantEvents: []glm46Event{
						glm46EventContent{content: "Here's the answer:"},
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "done"},
					},
				},
			},
		},
		{
			desc: "no thinking tag - whitespace preserved when no tags",
			steps: []step{
				{
					input: " \n content with leading whitespace",
					wantEvents: []glm46Event{
						glm46EventContent{content: " \n content with leading whitespace"},
					},
				},
			},
		},
		{
			desc: "whitespace after think close tag gets eaten",
			steps: []step{
				{
					input: "<think>thinking</think> \n\t content",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "whitespace after tool_call close tag gets eaten",
			steps: []step{
				{
					input: "<think></think><tool_call>test</tool_call> \n\t content",
					wantEvents: []glm46Event{
						glm46EventRawToolCall{raw: "test"},
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace (single chunk)",
			steps: []step{
				{
					input: "<think>thinking content ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking content"},
					},
				},
				{
					input: "</think>after",
					wantEvents: []glm46Event{
						glm46EventContent{content: "after"},
					},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace with newlines",
			steps: []step{
				{
					input: "<think>thinking\n\n ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "</think>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
		{
			desc: "thinking content trailing whitespace emitted when more content arrives",
			steps: []step{
				{
					input: "<think>thinking ",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "more thinking",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: " more thinking"},
					},
				},
				{
					input: "</think>",
					wantEvents: []glm46Event{},
				},
			},
		},
		{
			desc: "thinking content withholds trailing whitespace before partial close tag",
			steps: []step{
				{
					input: "<think>thinking </th",
					wantEvents: []glm46Event{
						glm46EventThinkingContent{content: "thinking"},
					},
				},
				{
					input: "ink>content",
					wantEvents: []glm46Event{
						glm46EventContent{content: "content"},
					},
				},
			},
		},
	}

	anyOnlies := false
	for _, tc := range cases {
		if tc.only {
			anyOnlies = true
		}
	}

	for _, tc := range cases {
		if anyOnlies && !tc.only {
			continue
		}

		t.Run(tc.desc, func(t *testing.T) {
			parser := GLM46Parser{}

			for i, step := range tc.steps {
				parser.buffer.WriteString(step.input)
				gotEvents := parser.parseEvents()

				if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
					// avoid deep equal on empty vs. nil slices
					continue
				}

				if !reflect.DeepEqual(gotEvents, step.wantEvents) {
					t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
				}
			}
		})
	}
}

// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
	testCases := []struct {
		name string
		xml string
		wantKeys []string
		wantValues []string
	}{
		{
			name: "alternating keys and values",
			xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
			wantKeys: []string{"first", "second", "third"},
			wantValues: []string{"A", "B", "C"},
		},
		{
			name: "all keys then all values",
			xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
			wantKeys: []string{"key1", "key2", "key3"},
			wantValues: []string{"val1", "val2", "val3"},
		},
		{
			name: "mixed grouping",
			xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
			wantKeys: []string{"a", "b", "c"},
			wantValues: []string{"1", "2", "3"},
		},
		{
			name: "reverse order - all values then all keys",
			xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
			wantKeys: []string{"x", "y", "z"},
			wantValues: []string{"X", "Y", "Z"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var parsed GLMToolCallXML
			err := xml.Unmarshal([]byte(tc.xml), &parsed)
			if err != nil {
				t.Fatalf("failed to unmarshal XML: %v", err)
			}

			if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
				t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
			}

			if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
				t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
			}
		})
	}
}

func TestGLM46ToolCallParsing(t *testing.T) {
	type testCase struct {
		name string
		rawToolCall string
		tools []api.Tool
		wantToolCall api.ToolCall
	}

	cases := []testCase{
		{
			name: "simple tool call",
			tools: []api.Tool{},
			rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get-current-weather",
					Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
				},
			},
		},
		{
			name: "tool call with typed parameters",
			tools: []api.Tool{
				tool("calculate", map[string]api.ToolProperty{
					"x": {Type: api.PropertyType{"number"}},
					"y": {Type: api.PropertyType{"integer"}},
					"enabled": {Type: api.PropertyType{"boolean"}},
					"items": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "calculate",
					Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
				},
			},
		},
		{
			name: "function name with whitespace",
			tools: []api.Tool{},
			rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get-weather",
					Arguments: args(`{"city": "Paris"}`),
				},
			},
		},
		{
			name: "values with special characters",
			tools: []api.Tool{},
			rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "execute-command",
					Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
				},
			},
		},
		{
			name: "unicode in function names and values",
			tools: []api.Tool{},
			rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "获取天气",
					Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
				},
			},
		},
		{
			name: "empty value",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param1": ""}`),
				},
			},
		},
		{
			name: "special chars in arg_key names",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
				},
			},
		},
		{
			name: "multiple consecutive ampersands",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param": "test &&&& more"}`),
				},
			},
		},
		{
			name: "mixed special chars together",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"param": "<>&<>&"}`),
				},
			},
		},
		{
			name: "newlines and tabs in parameter values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
	indented line2
line3</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
				},
			},
		},
		{
			name: "single and double quotes in values",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
				},
			},
		},
		{
			name: "CDATA-like content that should be treated as text",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
				},
			},
		},
		{
			name: "all special XML entities",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value><>&'"</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"entities": "<>&'\""}`),
				},
			},
		},
		{
			name: "order preservation with multiple parameters",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
				},
			},
		},
		{
			name: "order preservation with identical key names but different positions",
			tools: []api.Tool{},
			rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test-function",
					// Later occurrence should overwrite earlier one
					Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
				},
			},
		},
		{
			name: "array with mixed types",
			tools: []api.Tool{
				tool("process", map[string]api.ToolProperty{
					"items": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "process",
					Arguments: args(`{"items": [1, "hello", true, null]}`),
				},
			},
		},
		{
			name: "empty array",
			tools: []api.Tool{
				tool("test", map[string]api.ToolProperty{
					"tags": {Type: api.PropertyType{"array"}},
				}),
			},
			rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "test",
					Arguments: args(`{"tags": []}`),
				},
			},
		},
		{
			name: "anyOf array or string - with array of objects",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
				}),
			},
			// <tool_call>TodoWrite
			// <arg_key>todos</arg_key>
			// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
			// </tool_call>
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "TodoWrite",
					Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
				},
			},
		},
		{
			name: "anyOf array or string - with plain string",
			tools: []api.Tool{
				tool("TodoWrite", map[string]api.ToolProperty{
					"todos": {Type: api.PropertyType{"array", "string"}},
				}),
			},
			rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
			wantToolCall: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "TodoWrite",
					Arguments: args(`{"todos": "Error: could not load todos"}`),
				},
			},
		},
	}

	for i, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
			if err != nil {
				t.Errorf("case %d (%s): %v", i, tc.name, err)
			}
			if !toolCallEqual(gotToolCall, tc.wantToolCall) {
				t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
			}
		})
	}
}
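// Editor's sketch: driving the parser over a stream using only the public Add
// method. The chunk strings are invented for illustration and fmt is assumed
// imported; as an Example test this documents the withheld-whitespace behavior
// exercised above (trailing whitespace is held back until it is unambiguous).
func ExampleGLM46Parser_streaming() {
	p := GLM46Parser{}
	p.Init(nil, nil, nil)
	for i, chunk := range []string{"<think>plan", "</think>hello ", "world"} {
		content, thinking, calls, _ := p.Add(chunk, i == 2)
		fmt.Printf("content=%q thinking=%q calls=%d\n", content, thinking, len(calls))
	}
	// Output:
	// content="" thinking="plan" calls=0
	// content="hello" thinking="" calls=0
	// content=" world" thinking="" calls=0
}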
@@ -1,20 +0,0 @@
package parsers

import "github.com/ollama/ollama/api"

// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
	GLM46Parser
}

func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	// When thinking is enabled (nil or true), the prompt ends with <think>,
	// so model output starts directly with thinking content (no opening tag).
	if thinkValue == nil || thinkValue.Bool() {
		p.state = glm46ParserState_CollectingThinking
	}
	return tools
}
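// Editor's note: the practical difference from GLM46Parser, sketched with the
// Add method (values taken from the tests that follow): with thinking enabled
// the prompt already ends in <think>, so the first chunk the model emits is
// bare thinking text and the parser never sees an opening tag.
//
//	p := GLM47Parser{}
//	p.Init(nil, nil, nil)                             // nil ThinkValue => thinking on
//	_, thinking, _, _ := p.Add("plan</think>", false) // thinking == "plan"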
@@ -1,99 +0,0 @@
package parsers

import (
	"reflect"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestGLM47ParserAdd(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init([]api.Tool{
		tool("calculate", map[string]api.ToolProperty{
			"count": {Type: api.PropertyType{"integer"}},
			"enabled": {Type: api.PropertyType{"boolean"}},
		}),
	}, nil, nil)

	// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
	// so the model output does NOT include the opening <think> tag.
	content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "plan" {
		t.Fatalf("expected thinking 'plan', got %q", thinking)
	}
	if content != "Answer" {
		t.Fatalf("expected content 'Answer', got %q", content)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	expectedArgs := args(`{"count": 3, "enabled": true}`)
	if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
		t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
	}
}

func TestGLM47ParserNoThinkingContent(t *testing.T) {
	parser := GLM47Parser{}
	parser.Init(nil, nil, nil)

	// When thinking is enabled but model has no thinking to output,
	// it should output </think> immediately followed by content.
	content, thinking, calls, err := parser.Add("</think>Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserThinkingDisabled(t *testing.T) {
	parser := GLM47Parser{}
	// When thinking is disabled, parser stays in LookingForThinkingOpen state
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	// Model outputs plain content (prompt ended with </think>)
	content, thinking, calls, err := parser.Add("Plain answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer" {
		t.Fatalf("expected content 'Plain answer', got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestGLM47ParserToolCallEscaping(t *testing.T) {
	toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	expected := api.ToolCall{
		Function: api.ToolCallFunction{
			Name: "exec",
			Arguments: args(`{"expr": "a < b && c > d"}`),
		},
	}
	if !reflect.DeepEqual(toolCall, expected) {
		t.Fatalf("expected %#v, got %#v", expected, toolCall)
	}
}
@@ -1,498 +0,0 @@
package parsers

import (
	"encoding/json"
	"errors"
	"log/slog"
	"strconv"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
)

type LFM2ParserState int

const (
	LFM2CollectingThinking LFM2ParserState = iota
	LFM2CollectingContent
	LFM2CollectingToolCalls
)

const (
	lfm2ThinkingOpenTag  = "<think>"
	lfm2ThinkingCloseTag = "</think>"
	lfm2ToolCallStartTag = "<|tool_call_start|>"
	lfm2ToolCallEndTag   = "<|tool_call_end|>"
)

type LFM2Parser struct {
	state                    LFM2ParserState
	buffer                   strings.Builder
	hasThinkingSupport       bool
	needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
	needsContentLeadingTrim  bool // trim leading whitespace after </think> tag
}

func (p *LFM2Parser) HasToolSupport() bool {
	return true
}

func (p *LFM2Parser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}

func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	// Check both model capability AND request preference
	thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())

	if !thinkingEnabled {
		p.state = LFM2CollectingContent
		return
	}

	if prefill && lastMessage.Content != "" {
		p.state = LFM2CollectingContent
		return
	}

	p.state = LFM2CollectingThinking
	p.needsThinkingLeadingTrim = true
}

func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.setInitialState(lastMessage, thinkValue)
	return tools
}

type lfm2Event interface {
	isLFM2Event()
}

type lfm2EventThinkingContent struct {
	content string
}

type lfm2EventContent struct {
	content string
}

type lfm2EventToolCall struct {
	toolCall api.ToolCall
}

func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event()         {}
func (lfm2EventToolCall) isLFM2Event()        {}

func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case lfm2EventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case lfm2EventThinkingContent:
			thinkingSb.WriteString(event.content)
		case lfm2EventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *LFM2Parser) parseEvents() []lfm2Event {
	var all []lfm2Event

	keepLooping := true
	for keepLooping {
		var events []lfm2Event
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
	var events []lfm2Event
	bufStr := p.buffer.String()
	if bufStr == "" {
		return events, false
	}

	switch p.state {
	case LFM2CollectingThinking:
		// Strip opening <think> tag if present
		if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
			bufStr = bufStr[len(lfm2ThinkingOpenTag):]
			p.needsThinkingLeadingTrim = true
			p.buffer.Reset()
			p.buffer.WriteString(bufStr)
		}

		// Trim leading whitespace after <think> tag (may span multiple chunks)
		if p.needsThinkingLeadingTrim {
			if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
				bufStr = trimmed
				p.buffer.Reset()
				p.buffer.WriteString(bufStr)
			}
			// Clear flag once we have non-whitespace content or buffer is empty
			if len(bufStr) > 0 {
				p.needsThinkingLeadingTrim = false
			}
		}

		if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
			split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
			thinking := split[0]
			thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)

			remaining := split[1]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = LFM2CollectingContent
			p.needsThinkingLeadingTrim = false
			// Set flag to trim any additional whitespace that may arrive in later chunks
			p.needsContentLeadingTrim = len(remaining) == 0

			if len(thinking) > 0 {
				events = append(events, lfm2EventThinkingContent{content: thinking})
			}
			return events, true
		} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
			beforePartialTag := bufStr[:len(bufStr)-overlapLen]
			trailingLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, lfm2EventThinkingContent{content: unambiguous})
			}
			return events, false
		} else { // otherwise it's thinking content
			whitespaceLen := trailingWhitespaceLen(bufStr)
			ambiguousStart := len(bufStr) - whitespaceLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, lfm2EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case LFM2CollectingContent:
		// Trim leading whitespace after </think> tag (may span multiple chunks)
		if p.needsContentLeadingTrim {
			if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
				bufStr = trimmed
				p.buffer.Reset()
				p.buffer.WriteString(bufStr)
			}
			// Clear flag once we have non-whitespace content
			if len(bufStr) > 0 {
				p.needsContentLeadingTrim = false
			}
		}

		if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
			split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
			contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = LFM2CollectingToolCalls

			if len(contentBefore) > 0 {
				events = append(events, lfm2EventContent{content: contentBefore})
			}
			return events, true
		} else { // otherwise it's content
			p.buffer.Reset()
			if len(bufStr) > 0 {
				events = append(events, lfm2EventContent{content: bufStr})
			}
			return events, false
		}

	case LFM2CollectingToolCalls:
		// Look for complete tool call JSON between tags
		if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
			toolCallContent := bufStr[:idx]

			if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
				remaining := bufStr[idx+len(lfm2ToolCallEndTag):]

				// Check if there's another tool call
				if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
					remaining = remaining[len(lfm2ToolCallStartTag):]
				} else {
					// No more tool calls, go back to content
					remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
					p.state = LFM2CollectingContent
				}

				p.buffer.Reset()
				p.buffer.WriteString(remaining)

				for _, tc := range toolCalls {
					events = append(events, lfm2EventToolCall{toolCall: tc})
				}
				return events, true
			} else if err != nil {
				slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
			}
		}

		return events, false
	}

	return events, false
}

// parseToolCallsContent parses one or more tool calls from content.
// Supports JSON format and Python-style format, including multiple calls: [func1(...),func2(...)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
	content = strings.TrimSpace(content)

	// Try JSON format first: {"name": "func", "arguments": {...}}
	var parsed struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments"`
	}

	if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
		var args api.ToolCallFunctionArguments
		if len(parsed.Arguments) > 0 {
			if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
				return nil, err
			}
		} else {
			args = api.NewToolCallFunctionArguments()
		}

		return []api.ToolCall{{
			Function: api.ToolCallFunction{
				Name:      parsed.Name,
				Arguments: args,
			},
		}}, nil
	}

	// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
	return p.parsePythonStyleToolCalls(content)
}
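// Illustrative examples of inputs accepted by parseToolCallsContent (assumed,
// mirroring the two formats documented above):
//
//	{"name": "get_weather", "arguments": {"location": "Paris"}}   // JSON form
//	[bash(command='ls'),bash(command='pwd')]                      // Python-style form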
// parsePythonStyleToolCalls parses one or more Python-style tool calls.
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
	content = strings.TrimSpace(content)

	// Strip outer brackets if present: [func(...)] -> func(...)
	if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
		content = content[1 : len(content)-1]
	}

	var toolCalls []api.ToolCall

	// Parse multiple function calls separated by commas at the top level
	for len(content) > 0 {
		content = strings.TrimSpace(content)
		if content == "" {
			break
		}

		// Skip leading comma from previous iteration
		if strings.HasPrefix(content, ",") {
			content = strings.TrimSpace(content[1:])
			if content == "" {
				break
			}
		}

		// Find function name
		parenIdx := strings.Index(content, "(")
		if parenIdx == -1 {
			return nil, errors.New("invalid tool call: no opening parenthesis")
		}

		funcName := strings.TrimSpace(content[:parenIdx])
		if funcName == "" {
			return nil, errors.New("invalid tool call: empty function name")
		}

		// Find matching closing parenthesis
		closeIdx := findMatchingParen(content, parenIdx)
		if closeIdx == -1 {
			return nil, errors.New("invalid tool call: no matching closing parenthesis")
		}

		argsStr := content[parenIdx+1 : closeIdx]
		args := api.NewToolCallFunctionArguments()

		if argsStr != "" {
			if err := parsePythonArgs(argsStr, &args); err != nil {
				return nil, err
			}
		}

		toolCalls = append(toolCalls, api.ToolCall{
			Function: api.ToolCallFunction{
				Name:      funcName,
				Arguments: args,
			},
		})

		// Move past this function call
		content = content[closeIdx+1:]
	}

	if len(toolCalls) == 0 {
		return nil, errors.New("no tool calls found")
	}

	return toolCalls, nil
}

// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx.
// Returns -1 if not found. Handles nested parentheses and quoted strings.
func findMatchingParen(s string, openIdx int) int {
	depth := 1
	i := openIdx + 1
	for i < len(s) && depth > 0 {
		switch s[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				return i
			}
		case '\'', '"':
			// Skip quoted string
			quote := s[i]
			i++
			for i < len(s) && s[i] != quote {
				if s[i] == '\\' && i+1 < len(s) {
					i++ // skip escaped char
				}
				i++
			}
		}
		i++
	}
	return -1
}
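// Illustrative example: findMatchingParen("f(a, g(b), 'x)')", 1) returns the
// index of the final ')', skipping both the nested call and the quoted ')'.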
// parseToolCallContent parses a single tool call (for backward compatibility with tests)
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
	calls, err := p.parseToolCallsContent(content)
	if err != nil {
		return api.ToolCall{}, err
	}
	if len(calls) == 0 {
		return api.ToolCall{}, errors.New("no tool call found")
	}
	return calls[0], nil
}

// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
	// Simple state machine to parse key='value' pairs.
	// Handles: command='ls', flag="-la", count=42, enabled=true
	var key string
	i := 0

	for i < len(argsStr) {
		// Skip whitespace
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
			i++
		}
		if i >= len(argsStr) {
			break
		}

		// Parse key
		keyStart := i
		for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
			i++
		}
		if i >= len(argsStr) || argsStr[i] != '=' {
			return errors.New("invalid argument: expected '='")
		}
		key = strings.TrimSpace(argsStr[keyStart:i])
		i++ // skip '='

		// Skip whitespace after =
		for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
			i++
		}

		// Parse value
		var value string
		if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
			// Quoted string
			quote := argsStr[i]
			i++
			valueStart := i
			for i < len(argsStr) && argsStr[i] != quote {
				if argsStr[i] == '\\' && i+1 < len(argsStr) {
					i += 2 // skip escaped char
				} else {
					i++
				}
			}
			value = argsStr[valueStart:i]
			if i < len(argsStr) {
				i++ // skip closing quote
			}
			args.Set(key, value)
		} else {
			// Unquoted value (number, bool, etc.)
			valueStart := i
			for i < len(argsStr) && argsStr[i] != ',' {
				i++
			}
			value = strings.TrimSpace(argsStr[valueStart:i])

			// Try to parse as number or bool
			if v, err := strconv.ParseInt(value, 10, 64); err == nil {
				args.Set(key, v)
			} else if v, err := strconv.ParseFloat(value, 64); err == nil {
				args.Set(key, v)
			} else if value == "true" {
				args.Set(key, true)
			} else if value == "false" {
				args.Set(key, false)
			} else {
				args.Set(key, value)
			}
		}

		// Skip comma and whitespace
		for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
			i++
		}
	}

	return nil
}
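// Illustrative example: parsePythonArgs("command='ls -la', count=3, enabled=true", &args)
// sets {"command": "ls -la", "count": 3, "enabled": true}; unquoted values are
// coerced to int, float, or bool before falling back to string.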
File diff suppressed because it is too large
@@ -68,12 +68,6 @@ func ParserForName(name string) Parser {
		return &Nemotron3NanoParser{}
	case "functiongemma":
		return &FunctionGemmaParser{}
	case "glm-4.7":
		return &GLM47Parser{}
	case "lfm2":
		return &LFM2Parser{hasThinkingSupport: false}
	case "lfm2-thinking":
		return &LFM2Parser{hasThinkingSupport: true}
	default:
		return nil
	}
@@ -96,11 +96,3 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
	}
	return args
}

func args(s string) api.ToolCallFunctionArguments {
	var result api.ToolCallFunctionArguments
	if err := json.Unmarshal([]byte(s), &result); err != nil {
		panic("invalid JSON in args(): " + err.Error())
	}
	return result
}
@@ -1,110 +0,0 @@
package renderers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

type GLM46Renderer struct{}

func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	var lastUserIndex int
	for i, message := range messages {
		if message.Role == "user" {
			lastUserIndex = i
		}
	}

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(string(d) + "\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}\n")
		sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
		sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
		sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
		sb.WriteString("...\n")
		sb.WriteString("</tool_call>")
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>\n")
			sb.WriteString(message.Content)
			if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
				sb.WriteString("/nothink")
			}
		case "assistant":
			sb.WriteString("<|assistant|>")
			if i > lastUserIndex {
				if message.Thinking != "" {
					sb.WriteString("\n<think>" + message.Thinking + "</think>")
				} else {
					sb.WriteString("\n<think></think>")
				}
			}
			if message.Content != "" {
				sb.WriteString("\n" + message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
					for key, value := range toolCall.Function.Arguments.All() {
						sb.WriteString("<arg_key>" + key + "</arg_key>\n")

						var valueStr string
						if str, ok := value.(string); ok {
							valueStr = str
						} else {
							jsonBytes, err := json.Marshal(value)
							if err != nil {
								valueStr = fmt.Sprintf("%v", value)
							} else {
								valueStr = string(jsonBytes)
							}
						}

						sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
					}

					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("\n<tool_response>\n")
			sb.WriteString(message.Content)
			sb.WriteString("\n</tool_response>")
		case "system":
			sb.WriteString("<|system|>\n")
			sb.WriteString(message.Content)
		}
	}

	// Add generation prompt
	sb.WriteString("<|assistant|>")
	if thinkValue != nil && !thinkValue.Bool() {
		sb.WriteString("\n<think></think>\n")
	}

	return sb.String(), nil
}
@@ -1,223 +0,0 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

func TestGLM46Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
		skip       string
	}{
		{
			name: "basic",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with user assistant user",
			messages: []api.Message{
				{Role: "user", Content: "What is the capital of France?"},
				{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
				{Role: "user", Content: "Fantastic!"},
			},
			expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tools",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
		},
		{
			skip: "tool call ordering not guaranteed yet",
			name: "tool calls",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
							},
						},
					},
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
					ToolName: "get_weather",
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
					ToolName: "get_weather",
				},
				{
					Role:    "assistant",
					Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
		},
		{
			name: "think true",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "think false",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.skip != "" {
				t.Skip(tt.skip)
			}
			renderer := &GLM46Renderer{}
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
@@ -1,170 +0,0 @@
package renderers

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
//    The model thinks between tool calls and after receiving tool results.
//    This enables complex step-by-step reasoning: interpreting each tool output
//    before deciding what to do next. Thinking blocks are preserved and returned
//    with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
//    The model retains reasoning content from previous assistant turns in context.
//    This preserves reasoning continuity across multi-turn conversations. The
//    upstream API has a "clear_thinking" parameter to control this:
//    - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//    - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
//    Controls whether the model should reason on each turn. The upstream API
//    uses the "enable_thinking" parameter:
//    - enable_thinking=true: outputs <think> to start reasoning
//    - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
type GLM47Renderer struct{}

func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(formatGLM47ToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}

func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
	var sb strings.Builder
	for key, value := range args.All() {
		sb.WriteString("<arg_key>" + key + "</arg_key>")
		var valueStr string
		if str, ok := value.(string); ok {
			valueStr = str
		} else {
			jsonBytes, err := json.Marshal(value)
			if err != nil {
				valueStr = fmt.Sprintf("%v", value)
			} else {
				valueStr = string(jsonBytes)
			}
		}

		sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
	}

	return sb.String()
}
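// Illustrative example: arguments {"location": "Tokyo", "unit": "celsius"} render as
//
//	<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value>
//
// String values are written verbatim; all other values are JSON-encoded.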
// formatGLM47ToolJSON inserts a space after ':' and ',' separators that sit
// outside JSON string values, matching the spacing the tests below expect in
// rendered tool signatures.
func formatGLM47ToolJSON(raw []byte) string {
	var sb strings.Builder
	sb.Grow(len(raw) + len(raw)/10)

	inString := false
	escaped := false
	for i := range raw {
		ch := raw[i]
		sb.WriteByte(ch)

		if inString {
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == '"' {
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			continue
		}

		if ch == ':' || ch == ',' {
			sb.WriteByte(' ')
		}
	}

	return sb.String()
}
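// Illustrative example: formatGLM47ToolJSON([]byte(`{"a":1,"b":"x,y"}`)) returns
//
//	{"a": 1, "b": "x,y"}
//
// Separators inside string values are left untouched.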
@@ -1,191 +0,0 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

func TestGLM47Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic user message",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
		},
		{
			name: "thinking disabled",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "[gMASK]<sop><|user|>Hello<|assistant|></think>",
		},
		{
			name: "system and user",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
				{Role: "user", Content: "Hello"},
			},
			expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
		},
		{
			name: "multi-turn conversation",
			messages: []api.Message{
				{Role: "user", Content: "Hi"},
				{Role: "assistant", Content: "Hello there"},
				{Role: "user", Content: "How are you?"},
			},
			expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
		},
		{
			name: "assistant with reasoning_content",
			messages: []api.Message{
				{Role: "user", Content: "Answer with reasoning."},
				{Role: "assistant", Thinking: "Plan.", Content: "Done."},
			},
			expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
		},
		{
			name: "tool call with empty content",
			messages: []api.Message{
				{Role: "user", Content: "Weather?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
		},
		{
			name: "tool call with content",
			messages: []api.Message{
				{Role: "user", Content: "Weather?"},
				{
					Role:    "assistant",
					Content: "Let me check",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
				{Role: "assistant", Content: "It is 22C."},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
		},
		{
			name: "multiple tool calls and responses",
			messages: []api.Message{
				{Role: "user", Content: "Compare weather"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Tokyo"}`),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name:      "get_weather",
								Arguments: args(`{"location": "Paris"}`),
							},
						},
					},
				},
				{Role: "tool", Content: `{"temperature":22}`},
				{Role: "tool", Content: `{"temperature":18}`},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
						Parameters: api.ToolFunctionParameters{
							Type:       "object",
							Required:   []string{"location"},
							Properties: propsMap(`{"location": {"type": "string"}}`),
						},
					},
				},
			},
			expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
		},
		{
			name: "preserved thinking in multi-turn",
			messages: []api.Message{
				{Role: "user", Content: "Think step by step"},
				{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
				{Role: "user", Content: "Continue"},
			},
			expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			renderer := &GLM47Renderer{}
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
@@ -1,144 +0,0 @@
package renderers

import (
	"encoding/json"
	"strings"

	"github.com/ollama/ollama/api"
)

type LFM2Renderer struct {
	IsThinking bool
}

func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer

	// Extract first system message if present (to combine with tools)
	var firstSystemContent string
	startIdx := 0
	if len(messages) > 0 && messages[0].Role == "system" {
		firstSystemContent = messages[0].Content
		startIdx = 1
	}

	// Append tools to first system content
	if len(tools) > 0 {
		if firstSystemContent != "" {
			firstSystemContent += "\n"
		}
		firstSystemContent += "List of tools: ["
		for i, tool := range tools {
			toolJSON, err := json.Marshal(tool)
			if err != nil {
				return "", err
			}
			firstSystemContent += string(toolJSON)
			if i < len(tools)-1 {
				firstSystemContent += ", "
			}
		}
		firstSystemContent += "]"
	}

	// Output first system block if it has content
	if firstSystemContent != "" {
		sb.WriteString("<|im_start|>system\n")
		sb.WriteString(firstSystemContent)
		sb.WriteString("<|im_end|>\n")
	}

	// Find the index of the last assistant message for thinking stripping
	lastAssistantIndex := -1
	for i := len(messages) - 1; i >= startIdx; i-- {
		if messages[i].Role == "assistant" {
			lastAssistantIndex = i
			break
		}
	}

	// Track whether we need to add generation prompt
	needsGenerationPrompt := len(messages) > 0

	for i := startIdx; i < len(messages); i++ {
		message := messages[i]
		switch message.Role {
		case "system":
			// Additional system messages (after the first) are rendered normally
			sb.WriteString("<|im_start|>system\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")

		case "user":
			sb.WriteString("<|im_start|>user\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true

		case "assistant":
			sb.WriteString("<|im_start|>assistant\n")

			// Check if this is the last assistant message
			isLastAssistant := i == lastAssistantIndex

			// Process content (may need thinking stripped)
			content := message.Content

			// Handle thinking tags in assistant content
			keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
			if strings.Contains(content, "</think>") {
				parts := strings.SplitN(content, "</think>", 2)
				if len(parts) > 1 {
					if !isLastAssistant && !keepPastThinking {
						// Strip thinking entirely for past assistant messages
						content = strings.TrimSpace(parts[1])
					} else {
						// Preserve thinking but trim whitespace after </think>
						content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
					}
				}
			}

			if len(message.ToolCalls) > 0 {
				// Assistant with tool calls - write content first (if any after stripping)
				if content != "" {
					sb.WriteString(content)
				}

				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<|tool_call_start|>")
					toolCallJSON := map[string]any{
						"name":      toolCall.Function.Name,
						"arguments": toolCall.Function.Arguments,
					}
					callJSON, _ := json.Marshal(toolCallJSON)
					sb.WriteString(string(callJSON))
					sb.WriteString("<|tool_call_end|>")
				}
			} else {
				sb.WriteString(content)
			}

			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true

		case "tool":
			// Tool responses are rendered as plain messages per the chat template
			sb.WriteString("<|im_start|>tool\n")
			sb.WriteString(message.Content)
			sb.WriteString("<|im_end|>\n")
			needsGenerationPrompt = true
		}
	}

	// Add generation prompt
	if needsGenerationPrompt {
		sb.WriteString("<|im_start|>assistant\n")
		// Note: Model is a "thinking-only" model - it will output <think> itself
		// We don't add <think> tag to the prompt
	}

	return sb.String(), nil
}
@@ -1,427 +0,0 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

func TestLFM2Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic user message",
			messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multiple system messages rendered separately",
			messages: []api.Message{
				{Role: "system", Content: "First instruction."},
				{Role: "system", Content: "Second instruction."},
				{Role: "user", Content: "Hello!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multi-turn conversation",
			messages: []api.Message{
				{Role: "user", Content: "What is 2+2?"},
				{Role: "assistant", Content: "The answer is 4."},
				{Role: "user", Content: "Thanks!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "only system message",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
			name: "user-assistant-user: last assistant preserves thinking",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reasoning</think>A1"},
				{Role: "user", Content: "Q2"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// With two assistants, first is stripped (not last), second preserved (is last)
			name: "multi-turn thinking: first stripped, second preserved",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reason1</think>A1"},
				{Role: "user", Content: "Q2"},
				{Role: "assistant", Content: "<think>reason2</think>A2"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			// With thinking enabled (keep_past_thinking=true), both preserved
			name: "multi-turn thinking: both preserved when thinking enabled",
			messages: []api.Message{
				{Role: "user", Content: "Q1"},
				{Role: "assistant", Content: "<think>reason1</think>A1"},
				{Role: "user", Content: "Q2"},
				{Role: "assistant", Content: "<think>reason2</think>A2"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected:   "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "assistant with tool calls",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "assistant with content and tool calls",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather in Paris?"},
				{
					Role:    "assistant",
					Content: "Let me check.",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tool response",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
				{Role: "assistant", Content: "Let me check."},
				{Role: "tool", Content: "22C, Sunny"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "multiple tool calls",
			messages: []api.Message{
				{Role: "user", Content: "Get weather for Paris and London"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "Paris",
								}),
							},
						},
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "London",
								}),
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tools definitions with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are helpful."},
				{Role: "user", Content: "What's the weather?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: testPropsMap(map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							}),
							Required: []string{"location"},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "tools definitions without system message",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: testPropsMap(map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							}),
							Required: []string{"location"},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
		},
		{
			name: "multiple tools without system message",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get weather",
					},
				},
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_time",
						Description: "Get time",
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
		},
		{
			name: "user-tool sequence",
			messages: []api.Message{
				{Role: "user", Content: "Check weather"},
				{Role: "tool", Content: "22C"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "full tool call cycle",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Check weather"},
|
||||
{Role: "assistant", Content: "Let me check"},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "It's 22C"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "unicode content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "你好世界! مرحبا 🌍"},
|
||||
{Role: "assistant", Content: "Hello! 👋"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "newlines in content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
|
||||
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "empty assistant content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "user", Content: "OK"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Generation prompt does NOT include <think> - model outputs it
|
||||
name: "generation prompt has no think tag",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Think hard"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Interleaved: thinking before tool call - last assistant preserves thinking
|
||||
name: "thinking before tool call (last assistant)",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>I need to check the weather</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tool calls - first has thinking stripped
|
||||
name: "two assistants with tools: first thinking stripped",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Two assistants with tools - both preserved when thinking enabled
|
||||
name: "two assistants with tools: both preserved when thinking enabled",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "<think>checking</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
// Content before thinking before tool call
|
||||
name: "content then thinking then tool call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.<think>Using weather API</think>",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
}
|
||||
|
||||
renderer := &LFM2Renderer{IsThinking: true}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
|
||||
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
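The table above leans on two small test helpers, testArgs and testPropsMap, whose diff appears further below. As a reference sketch only — the real helper may be built differently, and this assumes encoding/json is imported — testArgs can be thought of as round-tripping a plain map through JSON into the API type:

// Hypothetical sketch of the testArgs helper used in the table above;
// it builds api.ToolCallFunctionArguments by round-tripping through JSON.
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
	data, err := json.Marshal(m)
	if err != nil {
		panic("testArgs: " + err.Error())
	}
	var result api.ToolCallFunctionArguments
	if err := json.Unmarshal(data, &result); err != nil {
		panic("testArgs: " + err.Error())
	}
	return result
}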
@@ -80,12 +80,6 @@ func rendererForName(name string) Renderer {
return &Nemotron3NanoRenderer{}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
default:
return nil
}
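A usage sketch for the lookup above (a hypothetical call site inside the renderers package, since rendererForName is unexported; Render's signature follows the tests earlier in this diff):

// Look up a renderer by name and render a conversation. rendererForName
// returns nil for unknown names, so callers must check before use.
r := rendererForName("lfm2-thinking")
if r == nil {
	return "", fmt.Errorf("unknown renderer")
}
prompt, err := r.Render([]api.Message{{Role: "user", Content: "Hi"}}, nil, &api.ThinkValue{Value: true})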
@@ -1,26 +1,6 @@
package renderers

import (
"encoding/json"

"github.com/ollama/ollama/api"
)

func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}

func propsMap(s string) *api.ToolPropertiesMap {
var result api.ToolPropertiesMap
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in propsMap(): " + err.Error())
}
return &result
}
import "github.com/ollama/ollama/api"

// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
@@ -28,7 +28,6 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -91,7 +90,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}

oldManifest, _ := manifest.ParseNamedManifest(name)
oldManifest, _ := ParseNamedManifest(name)

var baseLayers []*layerGGML
var err error
@@ -124,9 +123,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}

if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.GetBlobsPath(mf.Config.Digest)
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -343,7 +342,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := manifest.GetBlobsPath(files[fn])
blobPath, err := GetBlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -395,7 +394,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}

blobPath, err := manifest.GetBlobsPath(digest)
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -433,7 +432,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}

layer, err := manifest.NewLayer(t, mediaType)
layer, err := NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -466,7 +465,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}

func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []manifest.Layer
var layers []Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -551,13 +550,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}

for _, layer := range layers {
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
}
}

fn(api.ProgressResponse{Status: "writing manifest"})
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
if err := WriteManifest(name, *configLayer, layers); err != nil {
return err
}

@@ -578,7 +577,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}

blob, err := manifest.GetBlobsPath(layer.Digest)
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -600,7 +599,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
newLayer, err := NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -620,7 +619,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML

fn(api.ProgressResponse{Status: "parsing GGUF"})
blobPath, err := manifest.GetBlobsPath(digest)
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -655,7 +654,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}

layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -666,8 +665,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}

func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
func removeLayer(layers []Layer, mediatype string) []Layer {
return slices.DeleteFunc(layers, func(layer Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -681,7 +680,7 @@ func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
})
}

func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
func setTemplate(layers []Layer, t string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -691,7 +690,7 @@ func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
}

blob := strings.NewReader(t)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -700,11 +699,11 @@ func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
return layers, nil
}

func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
func setSystem(layers []Layer, s string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -713,9 +712,9 @@ func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
return layers, nil
}

func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
func setLicense(layers []Layer, l string) ([]Layer, error) {
blob := strings.NewReader(l)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -723,7 +722,7 @@ func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
return layers, nil
}

func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -732,7 +731,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
continue
}

digestPath, err := manifest.GetBlobsPath(layer.Digest)
digestPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -766,7 +765,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -774,7 +773,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
return layers, nil
}

func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -787,7 +786,7 @@ func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, er
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -795,7 +794,7 @@ func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, er
return layers, nil
}

func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -806,7 +805,7 @@ func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifes
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}

@@ -10,7 +10,6 @@ import (
"testing"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)

func TestConvertFromSafetensors(t *testing.T) {
@@ -18,7 +17,7 @@ func TestConvertFromSafetensors(t *testing.T) {

// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}

@@ -24,8 +24,6 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)

const maxRetries = 6
@@ -458,7 +456,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}

type downloadOpts struct {
n model.Name
mp ModelPath
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -467,10 +465,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
}

fp, err := manifest.GetBlobsPath(opts.digest)
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -494,8 +492,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.n.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err
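downloadBlob above dedupes concurrent pulls of the same digest through blobDownloadManager.LoadOrStore: only the caller that actually stores the entry prepares the download, and everyone else shares it. A self-contained sketch of that pattern with hypothetical names:

package main

import (
	"fmt"
	"sync"
)

var inflight sync.Map // digest -> *download

type download struct{ digest string }

// fetch returns the one shared *download for a digest; LoadOrStore only
// inserts when no entry exists, so concurrent callers get the same value
// and only the first caller (loaded == false) starts the transfer.
func fetch(digest string) *download {
	d, loaded := inflight.LoadOrStore(digest, &download{digest: digest})
	if !loaded {
		fmt.Println("starting transfer for", digest)
	}
	return d.(*download)
}

func main() {
	a := fetch("sha256-abc")
	b := fetch("sha256-abc")
	fmt.Println(a == b) // true: both callers share one download
}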
205 server/images.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -23,7 +24,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -274,22 +274,44 @@ func (m *Model) String() string {
return modelfile.String()
}

func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}

f, err := os.Open(fp)
if err != nil {
return nil, "", err
}
defer f.Close()

sha256sum := sha256.New()

var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}

return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}

func GetModel(name string) (*Model, error) {
n := model.ParseName(name)
mf, err := manifest.ParseNamedManifest(n)
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
if err != nil {
return nil, err
}

m := &Model{
Name: n.String(),
ShortName: n.DisplayShortest(),
Digest: mf.Digest(),
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: template.DefaultTemplate,
}

if mf.Config.Digest != "" {
filename, err := manifest.GetBlobsPath(mf.Config.Digest)
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
@@ -300,29 +322,29 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()

if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
}

for _, layer := range mf.Layers {
filename, err := manifest.GetBlobsPath(layer.Digest)
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}

switch layer.MediaType {
case "application/vnd.ollama.image.model":
m.ModelPath = filename
m.ParentModel = layer.From
model.ModelPath = filename
model.ParentModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
m.AdapterPaths = append(m.AdapterPaths, filename)
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
m.ProjectorPaths = append(m.ProjectorPaths, filename)
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -330,7 +352,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}

m.Template, err = template.Parse(string(bts))
model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -340,7 +362,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}

m.System = string(bts)
model.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -349,7 +371,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()

// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -359,7 +381,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()

if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -367,11 +389,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
m.License = append(m.License, string(bts))
model.License = append(model.License, string(bts))
}
}

return m, nil
return model, nil
}

func CopyModel(src, dst model.Name) error {
@@ -386,7 +408,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}

manifests, err := manifest.GetManifestPath()
manifests, err := GetManifestPath()
if err != nil {
return err
}
@@ -415,7 +437,7 @@ func CopyModel(src, dst model.Name) error {

func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := manifest.Manifests(true)
manifests, err := Manifests(true)
if err != nil {
return err
}
@@ -430,7 +452,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {

// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := manifest.GetBlobsPath(k)
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -446,7 +468,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {

func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := manifest.GetBlobsPath("")
p, err := GetBlobsPath("")
if err != nil {
return err
}
@@ -461,9 +483,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")

_, err := manifest.GetBlobsPath(name)
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -488,30 +510,63 @@ func PruneLayers() error {
return nil
}

func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}

if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}

for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}

entries, err = os.ReadDir(path)
if err != nil {
return err
}

if len(entries) > 0 {
return nil
}

return os.Remove(path)
}

return nil
}

func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
n := model.ParseName(name)
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})

if n.ProtocolScheme == "http" && !regOpts.Insecure {
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}

mf, err := manifest.ParseNamedManifest(n)
manifest, _, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}

var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}

// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := manifest.GetManifestPathForName(n)
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -519,7 +574,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -527,17 +582,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

for _, layer := range layers {
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}

fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
@@ -556,44 +611,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
n := model.ParseName(name)
mp := ParseModelPath(name)

// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
existingMf, err := manifest.ParseNamedManifest(n)
manifest, _, err := GetManifest(mp)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
for _, l := range existingMf.Layers {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
}

if n.ProtocolScheme == "http" && !regOpts.Insecure {
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}

fn(api.ProgressResponse{Status: "pulling manifest"})

mf, err := pullModelManifest(ctx, n, regOpts)
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}

var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}

// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -603,7 +658,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
n: n,
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -622,7 +677,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
fp, err := manifest.GetBlobsPath(layer.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -637,16 +692,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, mf.Config.Digest)
delete(deleteMap, manifest.Config.Digest)

fn(api.ProgressResponse{Status: "writing manifest"})

manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}

fp, err := manifest.GetManifestPathForName(n)
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -673,9 +728,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []manifest.Layer) bool {
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
@@ -683,7 +738,7 @@ func hasTensorLayers(layers []manifest.Layer) bool {
}

// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -692,12 +747,12 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}
}

destDir, err := manifest.GetBlobsPath("")
destDir, err := GetBlobsPath("")
if err != nil {
return err
}

base := n.BaseURL()
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -729,7 +784,7 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: n.DisplayNamespaceModel(),
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -740,12 +795,12 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer

// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}

fp, err := manifest.GetManifestPathForName(n)
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -757,7 +812,7 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}

// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -767,12 +822,12 @@ func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}
}

srcDir, err := manifest.GetBlobsPath("")
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}

base := n.BaseURL()
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -809,13 +864,13 @@ func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: n.Tag,
Repository: n.DisplayNamespaceModel(),
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}

func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -825,7 +880,7 @@ func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptio
}
defer resp.Body.Close()

var m manifest.Manifest
var m Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -987,7 +1042,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

func verifyBlob(digest string) error {
fp, err := manifest.GetBlobsPath(digest)
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
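GetManifest above hashes the manifest while decoding it: io.TeeReader copies every byte the JSON decoder consumes into a sha256 hash, so a single pass over the file yields both the parsed manifest and its digest (note that only bytes the decoder actually reads are hashed). A standalone sketch of the pattern:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

func main() {
	r := strings.NewReader(`{"schemaVersion":2}`)
	sum := sha256.New()
	var v map[string]any
	// TeeReader forwards the decoder's reads into the hash as a side effect.
	if err := json.NewDecoder(io.TeeReader(r, sum)).Decode(&v); err != nil {
		panic(err)
	}
	fmt.Println(hex.EncodeToString(sum.Sum(nil)))
}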
@@ -1,4 +1,4 @@
package manifest
package server

import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
Status string `json:"-"`
status string
}

const (
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
Status: fmt.Sprintf("%s %s", status, digest),
status: fmt.Sprintf("%s %s", status, digest),
}, nil
}

@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
Status: fmt.Sprintf("using existing layer %s", digest),
status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}

@@ -1,9 +1,10 @@
package manifest
package server

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -32,32 +33,6 @@ func (m *Manifest) Size() (size int64) {
return
}

func (m *Manifest) Digest() string {
return m.digest
}

func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}

// ReadConfigJSON reads and unmarshals a config layer as JSON.
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
blobPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
data, err := os.ReadFile(blobPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
}
return fmt.Errorf("config %q not found in manifest", configPath)
}

func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
@@ -99,7 +74,7 @@ func (m *Manifest) RemoveLayers() error {
if err != nil {
return err
}
if err := os.Remove(blob); os.IsNotExist(err) {
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -1,4 +1,4 @@
package manifest
package server

import (
"encoding/json"
@@ -13,7 +13,6 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -21,19 +20,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)

type layerGGML struct {
manifest.Layer
Layer
*ggml.GGML
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := manifest.ParseNamedManifest(name)
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}

m, err = manifest.ParseNamedManifest(name)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -42,7 +41,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}

for _, layer := range m.Layers {
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -51,7 +50,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := manifest.GetBlobsPath(layer.Digest)
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -82,12 +81,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}

layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})

if t.Parameters != nil {
@@ -96,7 +95,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}

layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
146 server/modelpath.go (new file)
@@ -0,0 +1,146 @@
package server

import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)

type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}

const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)

var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrModelPathInvalid = errors.New("invalid model path")
)

func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}

before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}

name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}

if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}

return mp
}

func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}

func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}

func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}

func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}

return path, nil
}

func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)

if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}

digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}

if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}

return path, nil
}
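A quick sketch of the restored helpers in use (a hypothetical call site inside the server package; the printed values follow directly from the defaults above):

// Sketch only, not a real example test in the repo.
func ExampleParseModelPath() {
	mp := ParseModelPath("llama3")
	fmt.Println(mp.GetFullTagname())         // registry.ollama.ai/library/llama3:latest
	fmt.Println(mp.GetShortTagname())        // llama3:latest
	fmt.Println(mp.GetNamespaceRepository()) // library/llama3

	// GetBlobsPath normalizes "sha256:" to "sha256-" and rejects anything
	// that is not a full 64-hex-digit sha256 digest.
	if _, err := GetBlobsPath("sha256-45640291"); err != nil {
		fmt.Println(err) // invalid digest format
	}
}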
153 server/modelpath_test.go (new file)
@@ -0,0 +1,153 @@
package server

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
tempDir := t.TempDir()

tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", tempDir)

got, err := GetBlobsPath(tc.digest)

require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}

func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)

if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}
@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
    name := t.Name
    quantize := strings.HasSuffix(name, "weight")

    // don't quantize vision encoder tensors (named with "v." prefix)
    quantize = quantize && !strings.HasPrefix(name, "v.")
    // don't quantize vision stuff
    quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
    quantize = quantize && !strings.Contains(name, "mm.")

    // quantize only 2D and 3D tensors (experts)
@@ -219,9 +219,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
    // NOTE: can't use LLM_TN here because the layer number is not known
    quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")

    // do not quantize LFM2's shortconv kernel weights
    quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")

    // do not quantize RWKV's time_mix_first tensors
    quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
    quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")

@@ -39,7 +39,6 @@ import (
    "github.com/ollama/ollama/fs/ggml"
    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/logutil"
    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/middleware"
    "github.com/ollama/ollama/model/parsers"
    "github.com/ollama/ollama/model/renderers"
@@ -221,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    // Handle image generation models
    if slices.Contains(m.Capabilities(), model.CapabilityImage) {
        s.handleImageGenerate(c, req, name.String(), checkpointStart)
        return
    }

    if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
        return
@@ -316,7 +321,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    // expire the runner if unload is requested (empty prompt, keep alive is 0)
    // expire the runner
    if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
        s.sched.expireRunner(m)

@@ -330,12 +335,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    // Handle image generation models
    if slices.Contains(m.Capabilities(), model.CapabilityImage) {
        s.handleImageGenerate(c, req, name.String(), checkpointStart)
        return
    }

    if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
        return
@@ -975,7 +974,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
    var zero model.Name
    existing, err := manifest.Manifests(true)
    existing, err := Manifests(true)
    if err != nil {
        return zero, err
    }
@@ -1019,7 +1018,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
        return
    }

    m, err := manifest.ParseNamedManifest(n)
    m, err := ParseNamedManifest(n)
    if err != nil {
        switch {
        case os.IsNotExist(err):
@@ -1081,7 +1080,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
    name := model.ParseName(req.Model)
    if !name.IsValid() {
        return nil, model.Unqualified(name)
        return nil, ErrModelPathInvalid
    }
    name, err := getExistingName(name)
    if err != nil {
@@ -1113,7 +1112,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {

    // For safetensors LLM models (experimental), populate details from config.json
    if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
        if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
        if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
            if arch, ok := info["general.architecture"].(string); ok && arch != "" {
                modelDetails.Family = arch
            }
@@ -1122,7 +1121,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
        }
    }
    // Get torch_dtype directly from config.json for quantization level
    if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
    if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
        modelDetails.QuantizationLevel = dtype
    }
}
@@ -1136,7 +1135,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
        msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
    }

    mf, err := manifest.ParseNamedManifest(name)
    manifest, err := ParseNamedManifest(name)
    if err != nil {
        return nil, err
    }
@@ -1148,11 +1147,8 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
        Details:      modelDetails,
        Messages:     msgs,
        Capabilities: m.Capabilities(),
        ModifiedAt:   mf.FileInfo().ModTime(),
        ModifiedAt:   manifest.fi.ModTime(),
        Requires:     m.Config.Requires,
        // Several integrations crash on a nil/omitempty+empty ModelInfo, so by
        // default we return an empty map.
        ModelInfo: make(map[string]any),
    }

    if m.Config.RemoteHost != "" {
@@ -1215,7 +1211,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
    if slices.Contains(m.Capabilities(), model.CapabilityImage) {
        // Populate tensor info if verbose
        if req.Verbose {
            if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
            if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
                resp.Tensors = tensors
            }
        }
@@ -1224,12 +1220,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {

    // For safetensors LLM models (experimental), populate ModelInfo from config.json
    if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
        if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
        if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
            resp.ModelInfo = info
        }
        // Populate tensor info if verbose
        if req.Verbose {
            if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
            if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
                resp.Tensors = tensors
            }
        }
@@ -1286,7 +1282,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}

func (s *Server) ListHandler(c *gin.Context) {
    ms, err := manifest.Manifests(true)
    ms, err := Manifests(true)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -1317,8 +1313,8 @@ func (s *Server) ListHandler(c *gin.Context) {
            RemoteModel: cf.RemoteModel,
            RemoteHost:  cf.RemoteHost,
            Size:        m.Size(),
            Digest:      m.Digest(),
            ModifiedAt:  m.FileInfo().ModTime(),
            Digest:      m.digest,
            ModifiedAt:  m.fi.ModTime(),
            Details: api.ModelDetails{
                Format: cf.ModelFormat,
                Family: cf.ModelFamily,
@@ -1377,7 +1373,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}

func (s *Server) HeadBlobHandler(c *gin.Context) {
    path, err := manifest.GetBlobsPath(c.Param("digest"))
    path, err := GetBlobsPath(c.Param("digest"))
    if err != nil {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
@@ -1393,7 +1389,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {

func (s *Server) CreateBlobHandler(c *gin.Context) {
    if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
        p, err := manifest.GetBlobsPath(ib)
        p, err := GetBlobsPath(ib)
        if err != nil {
            c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
@@ -1411,7 +1407,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
        }
    }

    path, err := manifest.GetBlobsPath(c.Param("digest"))
    path, err := GetBlobsPath(c.Param("digest"))
    if err != nil {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
@@ -1429,7 +1425,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
        return
    }

    layer, err := manifest.NewLayer(c.Request.Body, "")
    layer, err := NewLayer(c.Request.Body, "")
    if err != nil {
        c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -1629,7 +1625,7 @@ func Serve(ln net.Listener) error {
    slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
    slog.Info("server config", "env", envconfig.Values())

    blobsDir, err := manifest.GetBlobsPath("")
    blobsDir, err := GetBlobsPath("")
    if err != nil {
        return err
    }
@@ -1638,7 +1634,7 @@ func Serve(ln net.Listener) error {
    }

    if !envconfig.NoPrune() {
        if _, err := manifest.Manifests(false); err != nil {
        if _, err := Manifests(false); err != nil {
            slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
        } else {
            // clean up unused layers and manifests
@@ -1646,12 +1642,12 @@ func Serve(ln net.Listener) error {
                return err
            }

            manifestsPath, err := manifest.GetManifestPath()
            manifestsPath, err := GetManifestPath()
            if err != nil {
                return err
            }

            if err := manifest.PruneDirectory(manifestsPath); err != nil {
            if err := PruneDirectory(manifestsPath); err != nil {
                return err
            }
        }

@@ -25,7 +25,6 @@ import (
    "github.com/ollama/ollama/convert"
    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/fs/ggml"
    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/types/model"
)

@@ -224,15 +223,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
        t.Fatalf("expected status code 200, actual %d", w.Code)
    }

    mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
    manifest, err := ParseNamedManifest(model.ParseName("child"))
    if err != nil {
        t.Fatalf("parse manifest: %v", err)
    }
    if mf.Config.Digest == "" {
    if manifest.Config.Digest == "" {
        t.Fatalf("unexpected empty config digest for child manifest")
    }

    configPath, err := manifest.GetBlobsPath(mf.Config.Digest)
    configPath, err := GetBlobsPath(manifest.Config.Digest)
    if err != nil {
        t.Fatalf("config blob path: %v", err)
    }

@@ -10,7 +10,6 @@ import (
    "github.com/gin-gonic/gin"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/types/model"
)

@@ -94,13 +93,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
        t.Fatal(err)
    }

    config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
    config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
    if err != nil {
        t.Fatal(err)
    }

    // create a manifest with duplicate layers
    if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
    if err := WriteManifest(n, config, []Layer{config}); err != nil {
        t.Fatal(err)
    }

@@ -2101,95 +2101,3 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
        }
    })
}

func TestGenerateUnload(t *testing.T) {
    gin.SetMode(gin.TestMode)

    var loadFnCalled bool

    s := Server{
        sched: &Scheduler{
            pendingReqCh:    make(chan *LlmRequest, 1),
            finishedReqCh:   make(chan *LlmRequest, 1),
            expiredCh:       make(chan *runnerRef, 1),
            unloadedCh:      make(chan any, 1),
            loaded:          make(map[string]*runnerRef),
            newServerFn:     newMockServer(&mockRunner{}),
            getGpuFn:        getGpuFn,
            getSystemInfoFn: getSystemInfoFn,
            loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
                loadFnCalled = true
                req.successCh <- &runnerRef{llama: &mockRunner{}}
                return false
            },
        },
    }

    go s.sched.Run(t.Context())

    _, digest := createBinFile(t, ggml.KV{
        "general.architecture":          "llama",
        "llama.block_count":             uint32(1),
        "llama.context_length":          uint32(8192),
        "llama.embedding_length":        uint32(4096),
        "llama.attention.head_count":    uint32(32),
        "llama.attention.head_count_kv": uint32(8),
        "tokenizer.ggml.tokens":         []string{""},
        "tokenizer.ggml.scores":         []float32{0},
        "tokenizer.ggml.token_type":     []int32{0},
    }, []*ggml.Tensor{
        {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
    })

    w := createRequest(t, s.CreateHandler, api.CreateRequest{
        Model:  "test",
        Files:  map[string]string{"file.gguf": digest},
        Stream: &stream,
    })

    if w.Code != http.StatusOK {
        t.Fatalf("expected status 200, got %d", w.Code)
    }

    t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
        loadFnCalled = false

        w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
            Model:     "test",
            Prompt:    "",
            KeepAlive: &api.Duration{Duration: 0},
            Stream:    &stream,
        })

        if w.Code != http.StatusOK {
            t.Errorf("expected status 200, got %d", w.Code)
        }

        var resp api.GenerateResponse
        if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
            t.Fatalf("failed to unmarshal response: %v", err)
        }

        if resp.DoneReason != "unload" {
            t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
        }

        if !resp.Done {
            t.Error("expected done to be true")
        }

        if loadFnCalled {
            t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
        }
    })
}
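The removed test above pins down API behavior worth illustrating: a generate call with an empty prompt and a zero keep-alive unloads the model rather than loading it. A minimal client-side sketch (assuming the standard api.Client; the model name and error handling are illustrative):

    ctx := context.Background()
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }
    err = client.Generate(ctx, &api.GenerateRequest{
        Model:     "test",
        Prompt:    "",                         // empty prompt plus...
        KeepAlive: &api.Duration{Duration: 0}, // ...zero keep-alive requests an unload
    }, func(r api.GenerateResponse) error {
        // the final response should report Done == true and DoneReason == "unload"
        return nil
    })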
@@ -571,7 +571,6 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
        model:           req.model,
        modelPath:       req.model.ModelPath,
        llama:           server,
        Options:         &req.opts,
        loading:         false,
        sessionDuration: sessionDuration,
        totalSize:       server.TotalSize(),

@@ -21,14 +21,12 @@ import (

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/types/model"
)

var blobUploadManager sync.Map

type blobUpload struct {
    manifest.Layer
    Layer

    Total     int64
    Completed atomic.Int64
@@ -53,7 +51,7 @@ const (
)

func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
    p, err := manifest.GetBlobsPath(b.Digest)
    p, err := GetBlobsPath(b.Digest)
    if err != nil {
        return err
    }
@@ -61,7 +59,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
    if b.From != "" {
        values := requestURL.Query()
        values.Add("mount", b.Digest)
        values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
        values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
        requestURL.RawQuery = values.Encode()
    }

@@ -130,7 +128,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
    defer blobUploadManager.Delete(b.Digest)
    ctx, b.CancelFunc = context.WithCancel(ctx)

    p, err := manifest.GetBlobsPath(b.Digest)
    p, err := GetBlobsPath(b.Digest)
    if err != nil {
        b.err = err
        return
@@ -366,9 +364,9 @@ func (p *progressWriter) Rollback() {
    p.written = 0
}

func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
    requestURL := n.BaseURL()
    requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
    requestURL := mp.BaseURL()
    requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)

    resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
    switch {
@@ -390,8 +388,8 @@ func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *r
    data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
    upload := data.(*blobUpload)
    if !ok {
        requestURL := n.BaseURL()
        requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
        requestURL := mp.BaseURL()
        requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
        if err := upload.Prepare(ctx, requestURL, opts); err != nil {
            blobUploadManager.Delete(layer.Digest)
            return err

@@ -7,7 +7,6 @@ import (
    "errors"
    "fmt"
    "log/slog"
    "net/url"
    "path/filepath"
    "strings"
)
@@ -36,25 +35,22 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"

const (
    defaultHost           = "registry.ollama.ai"
    defaultNamespace      = "library"
    defaultTag            = "latest"
    defaultProtocolScheme = "https"
    defaultHost      = "registry.ollama.ai"
    defaultNamespace = "library"
    defaultTag       = "latest"
)

// DefaultName returns a name with the default values for the host, namespace,
// tag, and protocol scheme parts. The model and digest parts are empty.
// and tag parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
// - The default protocol scheme is ("https")
func DefaultName() Name {
    return Name{
        Host:           defaultHost,
        Namespace:      defaultNamespace,
        Tag:            defaultTag,
        ProtocolScheme: defaultProtocolScheme,
        Host:      defaultHost,
        Namespace: defaultNamespace,
        Tag:       defaultTag,
    }
}

@@ -91,11 +87,10 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
    Host           string
    Namespace      string
    Model          string
    Tag            string
    ProtocolScheme string
    Host      string
    Namespace string
    Model     string
    Tag       string
}

// ParseName parses and assembles a Name from a name string. The
@@ -165,9 +160,7 @@ func ParseNameBare(s string) Name {
    }

    scheme, host, ok := strings.Cut(s, "://")
    if ok {
        n.ProtocolScheme = scheme
    } else {
    if !ok {
        host = scheme
    }
    n.Host = host
@@ -196,13 +189,12 @@ func ParseNameFromFilepath(s string) (n Name) {
    return n
}

// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// Merge merges the host, namespace, and tag parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
    a.Host = cmp.Or(a.Host, b.Host)
    a.Namespace = cmp.Or(a.Namespace, b.Namespace)
    a.Tag = cmp.Or(a.Tag, b.Tag)
    a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
    return a
}

@@ -313,23 +305,6 @@ func (n Name) EqualFold(o Name) bool {
        strings.EqualFold(n.Tag, o.Tag)
}

// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
    return &url.URL{
        Scheme: n.ProtocolScheme,
        Host:   n.Host,
    }
}

// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
    var b strings.Builder
    b.WriteString(n.Namespace)
    b.WriteByte('/')
    b.WriteString(n.Model)
    return b.String()
}

func isValidLen(kind partKind, s string) bool {
    switch kind {
    case kindHost:

@@ -32,11 +32,10 @@ func TestParseNameParts(t *testing.T) {
    {
        in: "scheme://host:port/namespace/model:tag",
        want: Name{
            Host:           "host:port",
            Namespace:      "namespace",
            Model:          "model",
            Tag:            "tag",
            ProtocolScheme: "scheme",
            Host:      "host:port",
            Namespace: "namespace",
            Model:     "model",
            Tag:       "tag",
        },
        wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
    },

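A small sketch of the parsing change above (illustrative; strings.Cut is from the standard library): with ProtocolScheme removed from Name, the scheme is now only used to split the input and is then discarded.

    scheme, host, ok := strings.Cut("scheme://host:port/namespace/model:tag", "://")
    if !ok {
        host = scheme // no "://" present: the whole prefix is the host part
    }
    // host == "host:port"; the parsed Name no longer records "scheme"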
@@ -12,8 +12,8 @@ import (
    "fmt"
    "io"

    "github.com/ollama/ollama/manifest"
    "github.com/ollama/ollama/progress"
    "github.com/ollama/ollama/server"
    "github.com/ollama/ollama/types/model"
    "github.com/ollama/ollama/x/create"
)
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
    return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
        layer, err := manifest.NewLayer(r, mediaType)
        layer, err := server.NewLayer(r, mediaType)
        if err != nil {
            return create.LayerInfo{}, err
        }
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
    }

    // Create layer for quantized weight
    weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
    weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
    if err != nil {
        return nil, err
    }

    // Create layer for scales
    scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
    scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
    if err != nil {
        return nil, err
    }
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant

    // Add qbiases layer if present (affine mode)
    if qbiasData != nil {
        qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
        qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
        if err != nil {
            return nil, err
        }
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant

// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
    layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
    layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
    if err != nil {
        return nil, err
    }
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
    }

    // Create config layer blob
    configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
    configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
    if err != nil {
        return fmt.Errorf("failed to create config layer: %w", err)
    }

    // Convert LayerInfo to manifest.Layer
    manifestLayers := make([]manifest.Layer, 0, len(layers))
    // Convert LayerInfo to server.Layer
    serverLayers := make([]server.Layer, 0, len(layers))
    for _, l := range layers {
        manifestLayers = append(manifestLayers, manifest.Layer{
        serverLayers = append(serverLayers, server.Layer{
            MediaType: l.MediaType,
            Digest:    l.Digest,
            Size:      l.Size,
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
        if err != nil {
            return err
        }
        manifestLayers = append(manifestLayers, modelfileLayers...)
        serverLayers = append(serverLayers, modelfileLayers...)
    }

    return manifest.WriteManifest(name, configLayer, manifestLayers)
    return server.WriteManifest(name, configLayer, serverLayers)
}
}

// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
    var layers []manifest.Layer
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
    var layers []server.Layer

    if mf.Template != "" {
        layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
        layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
        if err != nil {
            return nil, fmt.Errorf("failed to create template layer: %w", err)
        }
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
    }

    if mf.System != "" {
        layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
        layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
        if err != nil {
            return nil, fmt.Errorf("failed to create system layer: %w", err)
        }
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
    }

    if mf.License != "" {
        layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
        layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
        if err != nil {
            return nil, fmt.Errorf("failed to create license layer: %w", err)
        }

@@ -16,6 +16,11 @@ import (
// Supported quantization types: "fp8" (affine 8-bit)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
    // Lazy init MLX when needed for quantization
    if err := mlx.InitMLX(); err != nil {
        return nil, nil, nil, nil, nil, nil, fmt.Errorf("MLX initialization failed: %w", err)
    }

    tmpDir := ensureTempDir()

    // Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
@@ -54,9 +59,6 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
    // Quantize based on quantization type
    var qweight, scales, qbiases *mlx.Array
    switch quantize {
    case "fp4":
        // affine mode: group_size=32, bits=4
        qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
    case "fp8":
        // affine mode: group_size=32, bits=8
        qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")

@@ -20,10 +20,10 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
    // Validate quantization type
    switch quantize {
    case "", "fp4", "fp8":
    case "", "fp8":
        // valid
    default:
        return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
        return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
    }

    var layers []LayerInfo
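For orientation, a rough sketch of what the affine quantization retained above computes per group of 32 weights (this is the usual definition of the scheme, an assumption rather than MLX's exact kernel):

    // For each contiguous group g of 32 weights w:
    //   scale_g = (max(w) - min(w)) / (2^bits - 1)
    //   bias_g  = min(w)
    //   q_i     = round((w_i - bias_g) / scale_g)   // packed into uint32 words
    // Dequantization recovers w_i ≈ q_i*scale_g + bias_g, which is why
    // mlx.Quantize returns three arrays: qweight, scales, and qbiases.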
@@ -7,17 +7,12 @@ import (
    "encoding/json"
    "flag"
    "fmt"
    "image"
    _ "image/jpeg"
    _ "image/png"
    "log"
    "os"
    "path/filepath"
    "runtime/pprof"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/models/flux2"
    "github.com/ollama/ollama/x/imagegen/models/gemma3"
    "github.com/ollama/ollama/x/imagegen/models/gpt_oss"
    "github.com/ollama/ollama/x/imagegen/models/llama"
@@ -51,8 +46,8 @@ func main() {
    imagePath := flag.String("image", "", "Image path for multimodal models")

    // Image generation params
    width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
    height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
    width := flag.Int("width", 1024, "Image width")
    height := flag.Int("height", 1024, "Image height")
    steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
    seed := flag.Int64("seed", 42, "Random seed")
    out := flag.String("output", "output.png", "Output path")
@@ -66,7 +61,6 @@ func main() {

    // Legacy mode flags
    zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
    flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
    qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
    qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
    var inputImages stringSlice
@@ -128,44 +122,6 @@ func main() {
        if err == nil {
            err = saveImageArray(img, *out)
        }
    case *flux2Flag:
        m := &flux2.Model{}
        if loadErr := m.Load(*modelPath); loadErr != nil {
            log.Fatal(loadErr)
        }
        // Load input images with EXIF orientation correction
        var loadedImages []image.Image
        for _, path := range inputImages {
            img, loadErr := loadImageWithEXIF(path)
            if loadErr != nil {
                log.Fatalf("Failed to load image %s: %v", path, loadErr)
            }
            loadedImages = append(loadedImages, img)
        }
        // When input images provided and user didn't override dimensions, use 0 to match input
        fluxWidth := int32(*width)
        fluxHeight := int32(*height)
        if len(loadedImages) > 0 && *width == 0 && *height == 0 {
            // Both unset, will auto-detect from input
        } else if len(loadedImages) > 0 && *width == 0 {
            fluxWidth = 0 // Compute from height + aspect ratio
        } else if len(loadedImages) > 0 && *height == 0 {
            fluxHeight = 0 // Compute from width + aspect ratio
        }
        var img *mlx.Array
        img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
            Prompt:        *prompt,
            Width:         fluxWidth,
            Height:        fluxHeight,
            Steps:         *steps,
            GuidanceScale: float32(*cfgScale),
            Seed:          *seed,
            CapturePath:   *gpuCapture,
            InputImages:   loadedImages,
        })
        if err == nil {
            err = saveImageArray(img, *out)
        }
    case *qwenImage:
        m, loadErr := qwen_image.LoadPersistent(*modelPath)
        if loadErr != nil {
@@ -320,8 +276,6 @@ func detectModelKind(modelPath string) (string, error) {
        switch index.ClassName {
        case "FluxPipeline", "ZImagePipeline":
            return "zimage", nil
        case "Flux2KleinPipeline":
            return "flux2", nil
        }
    }
    return "zimage", nil
@@ -342,12 +296,3 @@ func detectModelKind(modelPath string) (string, error) {

    return cfg.ModelType, nil
}

// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
func loadImageWithEXIF(path string) (image.Image, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        return nil, fmt.Errorf("read file: %w", err)
    }
    return imagegen.DecodeImage(data)
}

@@ -7,7 +7,6 @@ import (
    "encoding/base64"
    "fmt"
    "image"
    _ "image/jpeg"
    "image/png"
    "os"
    "path/filepath"
@@ -109,160 +108,3 @@ func clampF(v, min, max float32) float32 {
    }
    return v
}

// DecodeImage decodes image bytes with EXIF orientation applied.
func DecodeImage(data []byte) (image.Image, error) {
    orientation := readJPEGOrientation(data)

    img, _, err := image.Decode(bytes.NewReader(data))
    if err != nil {
        return nil, err
    }

    return applyOrientation(img, orientation), nil
}

// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
    if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
        return 1 // Not JPEG
    }

    r := bytes.NewReader(data[2:])
    for {
        var marker [2]byte
        if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
            return 1
        }

        if marker[1] == 0xE1 { // APP1 (EXIF)
            var lenBytes [2]byte
            if _, err := r.Read(lenBytes[:]); err != nil {
                return 1
            }
            segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
            if segLen < 14 {
                r.Seek(int64(segLen), 1)
                continue
            }
            seg := make([]byte, segLen)
            if _, err := r.Read(seg); err != nil {
                return 1
            }
            if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
                return parseTIFFOrientation(seg[6:])
            }
            continue
        }

        if marker[1] == 0xD9 || marker[1] == 0xDA {
            return 1 // EOI or SOS
        }
        if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
            continue // RST markers
        }

        var lenBytes [2]byte
        if _, err := r.Read(lenBytes[:]); err != nil {
            return 1
        }
        segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
        if segLen > 0 {
            r.Seek(int64(segLen), 1)
        }
    }
}

func parseTIFFOrientation(tiff []byte) int {
    if len(tiff) < 8 {
        return 1
    }

    var big bool
    switch string(tiff[:2]) {
    case "MM":
        big = true
    case "II":
        big = false
    default:
        return 1
    }

    u16 := func(b []byte) uint16 {
        if big {
            return uint16(b[0])<<8 | uint16(b[1])
        }
        return uint16(b[1])<<8 | uint16(b[0])
    }
    u32 := func(b []byte) uint32 {
        if big {
            return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
        }
        return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
    }

    if u16(tiff[2:4]) != 42 {
        return 1
    }

    ifdOffset := u32(tiff[4:8])
    if int(ifdOffset)+2 > len(tiff) {
        return 1
    }

    numEntries := u16(tiff[ifdOffset : ifdOffset+2])
    for i := range int(numEntries) {
        offset := ifdOffset + 2 + uint32(i)*12
        if int(offset)+12 > len(tiff) {
            break
        }
        if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
            o := int(u16(tiff[offset+8 : offset+10]))
            if o >= 1 && o <= 8 {
                return o
            }
            return 1
        }
    }
    return 1
}

func applyOrientation(img image.Image, orientation int) image.Image {
    if orientation <= 1 || orientation > 8 {
        return img
    }

    bounds := img.Bounds()
    w, h := bounds.Dx(), bounds.Dy()

    outW, outH := w, h
    if orientation >= 5 {
        outW, outH = h, w
    }

    out := image.NewRGBA(image.Rect(0, 0, outW, outH))
    for y := range h {
        for x := range w {
            var dx, dy int
            switch orientation {
            case 2:
                dx, dy = w-1-x, y
            case 3:
                dx, dy = w-1-x, h-1-y
            case 4:
                dx, dy = x, h-1-y
            case 5:
                dx, dy = y, x
            case 6:
                dx, dy = h-1-y, x
            case 7:
                dx, dy = h-1-y, w-1-x
            case 8:
                dx, dy = y, w-1-x
            }
            out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
        }
    }
    return out
}
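An illustrative use of the EXIF-aware decoder above (file name and error handling are placeholders, not from the diff):

    data, err := os.ReadFile("photo.jpg")
    if err != nil {
        log.Fatal(err)
    }
    img, err := imagegen.DecodeImage(data) // orientation already applied
    if err != nil {
        log.Fatal(err)
    }
    // For orientations 5-8 the bounds come back with width and height swapped.
    fmt.Println(img.Bounds())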
@@ -6,9 +6,8 @@ import (
    "io"
    "os"
    "path/filepath"
    "runtime"
    "strings"

    "github.com/ollama/ollama/envconfig"
)

// ManifestLayer represents a layer in the manifest.
@@ -33,15 +32,31 @@ type ModelManifest struct {
    BlobDir string
}

// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
    return filepath.Join(envconfig.Models(), "blobs")
    home, err := os.UserHomeDir()
    if err != nil {
        home = "."
    }
    switch runtime.GOOS {
    case "darwin":
        return filepath.Join(home, ".ollama", "models", "blobs")
    case "linux":
        return filepath.Join(home, ".ollama", "models", "blobs")
    case "windows":
        return filepath.Join(home, ".ollama", "models", "blobs")
    default:
        return filepath.Join(home, ".ollama", "models", "blobs")
    }
}

// DefaultManifestDir returns the manifest storage directory.
// Respects OLLAMA_MODELS.

// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
    return filepath.Join(envconfig.Models(), "manifests")
    home, err := os.UserHomeDir()
    if err != nil {
        home = "."
    }
    return filepath.Join(home, ".ollama", "models", "manifests")
}

// LoadManifest loads a manifest for the given model name.

@@ -1,26 +0,0 @@
package imagegen

import (
    "path/filepath"
    "testing"
)

func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
    modelsDir := filepath.Join(t.TempDir(), "models")

    // Simulate packaged/systemd environment
    t.Setenv("OLLAMA_MODELS", modelsDir)
    t.Setenv("HOME", "/usr/share/ollama")

    // Manifest dir must respect OLLAMA_MODELS
    wantManifest := filepath.Join(modelsDir, "manifests")
    if got := DefaultManifestDir(); got != wantManifest {
        t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
    }

    // Blob dir must respect OLLAMA_MODELS
    wantBlobs := filepath.Join(modelsDir, "blobs")
    if got := DefaultBlobDir(); got != wantBlobs {
        t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
    }
}
@@ -24,8 +24,9 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}

// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
    "ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
    "FluxPipeline":   20 * GB, // ~20GB for Flux
    "ZImagePipeline":    21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
    "FluxPipeline":      21 * GB, // ~21GB for Flux (same architecture)
    "QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
}

// CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -71,38 +72,26 @@ func ResolveModelName(modelName string) string {
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
    className := DetectModelType(modelName)
    if estimate, ok := modelVRAMEstimates[className]; ok {
        return estimate
    }
    return 21 * GB
}

// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
    manifest, err := LoadManifest(modelName)
    if err != nil {
        return ""
        return 21 * GB
    }

    data, err := manifest.ReadConfig("model_index.json")
    if err != nil {
        return ""
        return 21 * GB
    }

    // Parse just the class name
    var index struct {
        Architecture string `json:"architecture"`
        ClassName    string `json:"_class_name"`
        ClassName string `json:"_class_name"`
    }
    if err := json.Unmarshal(data, &index); err != nil {
        return ""
        return 21 * GB
    }

    // Prefer architecture (Ollama format), fall back to _class_name (diffusers)
    if index.Architecture != "" {
        return index.Architecture
    if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
        return estimate
    }
    return index.ClassName
    return 21 * GB
}

@@ -72,8 +72,9 @@ func TestCheckMemoryRequirements(t *testing.T) {
func TestModelVRAMEstimates(t *testing.T) {
    // Verify the VRAM estimates map has expected entries
    expected := map[string]uint64{
        "ZImagePipeline": 21 * GB,
        "FluxPipeline":   20 * GB,
        "ZImagePipeline":    21 * GB,
        "FluxPipeline":      21 * GB,
        "QwenImagePipeline": 80 * GB,
    }

    for name, expectedVRAM := range expected {

@@ -1137,27 +1137,6 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
    return RMSNorm(x, ones, eps)
}

// LayerNorm applies layer normalization without learnable params
// (x - mean) / sqrt(var + eps)
func LayerNorm(x *Array, eps float32) *Array {
    return LayerNormWithWeightBias(x, nil, nil, eps)
}

// LayerNormWithWeightBias computes layer normalization using mlx.fast
// weight and bias can be nil for elementwise_affine=False
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
    res := C.mlx_array_new()
    var wc, bc C.mlx_array
    if weight != nil {
        wc = weight.c
    }
    if bias != nil {
        bc = bias.c
    }
    C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
    return newArray(res)
}

// RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
    res := C.mlx_array_new()
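For reference, the normalization the removed helpers computed, matching the comment above (weight gamma and bias beta are optional and may be nil):

    y = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta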
@@ -1,539 +0,0 @@
//go:build mlx

// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2

import (
    "context"
    "encoding/json"
    "fmt"
    "image"
    "math"
    "time"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/models/qwen3"
    "github.com/ollama/ollama/x/imagegen/tokenizer"
    "golang.org/x/image/draw"
)

// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
    Prompt        string
    Width         int32                      // Image width (default: 1024)
    Height        int32                      // Image height (default: 1024)
    Steps         int                        // Denoising steps (default: 4 for Klein)
    GuidanceScale float32                    // Guidance scale (default: 1.0, Klein doesn't need CFG)
    Seed          int64                      // Random seed
    Progress      func(step, totalSteps int) // Optional progress callback
    CapturePath   string                     // GPU capture path (debug)
    InputImages   []image.Image              // Reference images for image conditioning (already loaded)
}

// Model represents a FLUX.2 Klein model.
type Model struct {
    ModelName       string
    Tokenizer       *tokenizer.Tokenizer
    TextEncoder     *qwen3.TextEncoder
    Transformer     *Flux2Transformer2DModel
    VAE             *AutoencoderKLFlux2
    SchedulerConfig *SchedulerConfig
}

// TextEncoderLayerIndices are the layers from which to extract text embeddings.
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
var TextEncoderLayerIndices = []int{8, 17, 26}

// Load loads the FLUX.2 Klein model from ollama blob storage.
func (m *Model) Load(modelName string) error {
    fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
    start := time.Now()

    if mlx.GPUIsAvailable() {
        mlx.SetDefaultDeviceGPU()
        mlx.EnableCompile()
    }

    m.ModelName = modelName

    // Load manifest
    manifest, err := imagegen.LoadManifest(modelName)
    if err != nil {
        return fmt.Errorf("load manifest: %w", err)
    }

    // Load tokenizer
    fmt.Print("  Loading tokenizer... ")
    tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
    if err != nil {
        return fmt.Errorf("tokenizer: %w", err)
    }

    tokConfig := &tokenizer.TokenizerConfig{}
    if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
        tokConfig.TokenizerConfigJSON = data
    }
    if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
        tokConfig.GenerationConfigJSON = data
    }
    if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
        tokConfig.SpecialTokensMapJSON = data
    }

    tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
    if err != nil {
        return fmt.Errorf("tokenizer: %w", err)
    }
    m.Tokenizer = tok
    fmt.Println("✓")

    // Load text encoder
    m.TextEncoder = &qwen3.TextEncoder{}
    if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
        return fmt.Errorf("text encoder: %w", err)
    }

    // Load transformer
    m.Transformer = &Flux2Transformer2DModel{}
    if err := m.Transformer.Load(manifest); err != nil {
        return fmt.Errorf("transformer: %w", err)
    }

    // Load VAE
    m.VAE = &AutoencoderKLFlux2{}
    if err := m.VAE.Load(manifest); err != nil {
        return fmt.Errorf("VAE: %w", err)
    }

    // Evaluate all weights in a single batch (reduces GPU sync overhead)
    fmt.Print("  Evaluating weights... ")
    allWeights := mlx.Collect(m.TextEncoder)
    allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
    allWeights = append(allWeights, mlx.Collect(m.VAE)...)
    mlx.Eval(allWeights...)
    fmt.Println("✓")

    // Load scheduler config
    m.SchedulerConfig = DefaultSchedulerConfig()
    if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
        if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
            fmt.Printf("  Warning: failed to parse scheduler config: %v\n", err)
        }
    }

    mem := mlx.MetalGetActiveMemory()
    fmt.Printf("  Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))

    return nil
}

// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
    return m.GenerateFromConfig(context.Background(), &GenerateConfig{
        Prompt: prompt,
        Width:  width,
        Height: height,
        Steps:  steps,
        Seed:   seed,
    })
}

// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
    return m.GenerateFromConfig(context.Background(), &GenerateConfig{
        Prompt:   prompt,
        Width:    width,
        Height:   height,
        Steps:    steps,
        Seed:     seed,
        Progress: progress,
    })
}

// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
    start := time.Now()
    result, err := m.generate(ctx, cfg)
    if err != nil {
        return nil, err
    }
    fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
    return result, nil
}

// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
    return m.GenerateFromConfig(ctx, &GenerateConfig{
        Prompt:   prompt,
        Width:    width,
        Height:   height,
        Steps:    steps,
        Seed:     seed,
        Progress: progress,
    })
}

// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048

// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
const MaxRefPixels = 728 * 728

// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
    // Enable MLX compilation for fused kernels
    mlx.EnableCompile()

    // Apply defaults
    if cfg.Steps <= 0 {
        cfg.Steps = 4 // Klein default: 4 steps for distilled model
    }
    if cfg.GuidanceScale <= 0 {
        cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
    }

    // Determine output dimensions
    if len(cfg.InputImages) > 0 {
        // With input images, compute missing dimension from aspect ratio
        // Images are already EXIF-rotated by the caller
        bounds := cfg.InputImages[0].Bounds()
        imgW, imgH := bounds.Dx(), bounds.Dy()
        aspectRatio := float64(imgH) / float64(imgW)
        if cfg.Width > 0 && cfg.Height <= 0 {
            // Width specified, compute height
            cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
        } else if cfg.Height > 0 && cfg.Width <= 0 {
            // Height specified, compute width
            cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
        } else if cfg.Width <= 0 && cfg.Height <= 0 {
            // Neither specified, use input dimensions
            cfg.Width = int32(imgW)
            cfg.Height = int32(imgH)
        }
    }
    if cfg.Width <= 0 {
        cfg.Width = 1024
    }
    if cfg.Height <= 0 {
        cfg.Height = 1024
    }

    // Cap to max pixels, preserve aspect ratio, round to multiple of 16
    pixels := int(cfg.Width) * int(cfg.Height)
    if pixels > MaxOutputPixels {
        scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
        cfg.Width = int32(math.Round(float64(cfg.Width)*scale/16) * 16)
        cfg.Height = int32(math.Round(float64(cfg.Height)*scale/16) * 16)
    }
    cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
    cfg.Width = int32((cfg.Width + 8) / 16 * 16)
    fmt.Printf("  Output: %dx%d\n", cfg.Width, cfg.Height)

    tcfg := m.Transformer.TransformerConfig
    patchSize := m.VAE.Config.PatchSize

    // Latent dimensions: image / 8 (VAE downscale) / patch_size
    latentH := cfg.Height / 8
    latentW := cfg.Width / 8
    patchH := latentH / patchSize[0]
    patchW := latentW / patchSize[1]
    imgSeqLen := patchH * patchW

    // Text encoding with multi-layer extraction (no padding, use true sequence length)
    fmt.Print("  Encoding prompt... ")
    promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
    fmt.Println("✓")

    // Encode reference images if provided
    var refTokens *ImageCondTokens
    var refHeights, refWidths []int32
    if len(cfg.InputImages) > 0 {
        fmt.Printf("  Encoding %d reference image(s):\n", len(cfg.InputImages))

        var err error
        refTokens, err = m.EncodeImageRefs(cfg.InputImages)
        if err != nil {
            return nil, fmt.Errorf("encode reference images: %w", err)
        }

        // Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
        limitPixels := MaxRefPixels
        if len(cfg.InputImages) > 1 {
            limitPixels = MaxRefPixels / 2
        }
        for _, img := range cfg.InputImages {
            _, w, h := PrepareImage(img, limitPixels)
            refHeights = append(refHeights, int32(h/16))
            refWidths = append(refWidths, int32(w/16))
        }
    }

    // Scheduler
    scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
    scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))

// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
|
||||
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
|
||||
latentChannels := m.VAE.Config.LatentChannels
|
||||
packedChannels := latentChannels * 4 // 32 * 4 = 128
|
||||
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
|
||||
|
||||
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
|
||||
// This matches diffusers' _pack_latents
|
||||
patches := packLatents(latents)
|
||||
noiseSeqLen := patches.Shape()[1]
|
||||
|
||||
// RoPE cache - includes reference images if present
|
||||
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
|
||||
|
||||
// Cleanup setup arrays when done
|
||||
defer func() {
|
||||
rope.Cos.Free()
|
||||
rope.Sin.Free()
|
||||
promptEmbeds.Free()
|
||||
if refTokens != nil {
|
||||
refTokens.Tokens.Free()
|
||||
}
|
||||
}()
|
||||
|
||||
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
|
||||
timesteps := make([]*mlx.Array, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
|
||||
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
|
||||
}
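
	// Note: Timesteps[i] is sigma_i * NumTrainTimesteps, so the division here
	// stores each timestep normalized to [0, 1]; the transformer rescales by
	// 1000 internally before computing the sinusoidal embedding.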

	// Evaluate setup arrays
	fmt.Print(" Evaluating setup... ")
	setupStart := time.Now()
	toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
	toEval = append(toEval, timesteps...)
	if refTokens != nil {
		toEval = append(toEval, refTokens.Tokens)
	}
	mlx.Eval(toEval...)
	mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
	fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))

	if cfg.Progress != nil {
		cfg.Progress(0, cfg.Steps)
	}

	loopStart := time.Now()
	stepStart := time.Now()

	// Denoising loop
	for i := 0; i < cfg.Steps; i++ {
		// Check for cancellation
		if ctx != nil {
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			default:
			}
		}

		// GPU capture on step 2 if requested
		if cfg.CapturePath != "" && i == 1 {
			mlx.MetalStartCapture(cfg.CapturePath)
		}

		timestep := timesteps[i]

		// Prepare input - concatenate noise patches with reference tokens if present
		imgInput := patches
		if refTokens != nil {
			imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
		}

		// Transformer forward pass
		output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)

		// If we concatenated reference tokens, slice to only get noise portion
		if refTokens != nil {
			output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
		}

		// Scheduler step (keep reference to old patches for the computation graph)
		newPatches := scheduler.Step(output, patches, i)

		if cfg.CapturePath != "" && i == 1 {
			mlx.MetalStopCapture()
		}

		mlx.Eval(newPatches)
		patches = newPatches

		elapsed := time.Since(stepStart).Seconds()
		peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
		if i == 0 {
			fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
		} else {
			fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
		}
		stepStart = time.Now()
		if cfg.Progress != nil {
			cfg.Progress(i+1, cfg.Steps)
		}
	}

	loopTime := time.Since(loopStart).Seconds()
	peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
	fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
		cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)

	// Free timesteps now that denoising is done
	for _, ts := range timesteps {
		ts.Free()
	}

	// VAE decode with tiling for larger images
	fmt.Print(" Decoding VAE... ")
	vaeStart := time.Now()
	// Enable tiling for images > 512x512 (latent > 64x64)
	// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
	if patchH*2 > 64 || patchW*2 > 64 {
		m.VAE.Tiling = DefaultTilingConfig()
	}
	decoded := m.VAE.Decode(patches, patchH, patchW)
	mlx.Eval(decoded)

	// Free patches now that decode is done
	patches.Free()

	fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	return decoded, nil
}
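
// Illustrative sketch (not part of this change): a minimal caller of the
// pipeline above, assuming a GenerateConfig with the fields used in generate;
// the model name is hypothetical. Zero values for Steps, GuidanceScale,
// Width, and Height pick up the Klein defaults applied at the top of generate.
//
//	m, err := LoadPersistent("flux2-klein") // hypothetical model name
//	if err != nil {
//		log.Fatal(err)
//	}
//	out, err := m.generate(context.Background(), &GenerateConfig{
//		Prompt: "a lighthouse at dusk",
//		Seed:   42,
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer out.Free()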

// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
func packLatents(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]
	// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
	x = mlx.Reshape(x, B, C, H*W)
	return mlx.Transpose(x, 0, 2, 1)
}
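
// For example, the [1, 128, 64, 64] packed noise created in generate becomes
// [1, 4096, 128]: one row of 128 channels per latent patch, in row-major
// (H then W) order, matching the position IDs built by PrepareLatentIDs.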

// LoadPersistent loads the model and keeps it in memory for repeated use.
func LoadPersistent(modelName string) (*Model, error) {
	m := &Model{}
	if err := m.Load(modelName); err != nil {
		return nil, err
	}
	return m, nil
}

// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
const ImageRefScale = 10

// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
// Returns the processed image and its dimensions.
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
	bounds := img.Bounds()
	w, h := bounds.Dx(), bounds.Dy()

	// Cap pixels if needed (like diffusers cap_pixels)
	if limitPixels > 0 && w*h > limitPixels {
		scale := math.Sqrt(float64(limitPixels) / float64(w*h))
		w = int(float64(w) * scale)
		h = int(float64(h) * scale)
	}

	// Round down to multiple of 16
	w = (w / 16) * 16
	h = (h / 16) * 16

	if w < 16 {
		w = 16
	}
	if h < 16 {
		h = 16
	}

	// Resize using high-quality bicubic interpolation (approximates diffusers' default lanczos)
	resized := image.NewRGBA(image.Rect(0, 0, w, h))
	draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)

	return resized, w, h
}
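
// Worked example with a hypothetical 1,000,000-pixel limit: a 3000x2000 input
// has 6,000,000 pixels, so scale = sqrt(1/6) ≈ 0.408; the scaled 1224x816
// rounds down to 1216x816, which stays under the limit (992,256 pixels).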

// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
func ImageToTensor(img image.Image) *mlx.Array {
	bounds := img.Bounds()
	w, h := bounds.Dx(), bounds.Dy()

	// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
	data := make([]float32, 3*h*w)

	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
			r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
			// RGBA returns 16-bit values, convert to [-1, 1]
			data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
			data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
			data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
		}
	}

	arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
	return arr
}
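
// The mapping above takes the high byte of each 16-bit channel (0..255) and
// applies v/127.5 - 1, so 0 -> -1.0, 128 -> ~0.004, and 255 -> 1.0.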

// ImageCondTokens holds encoded reference image tokens.
type ImageCondTokens struct {
	Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
}

// EncodeImageRefs encodes reference images using the VAE.
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
	if len(images) == 0 {
		return nil, nil
	}

	// Limit reference images to reduce attention memory
	limitPixels := MaxRefPixels
	if len(images) > 1 {
		limitPixels = MaxRefPixels / 2
	}

	var allTokens []*mlx.Array

	for _, img := range images {
		// Prepare image (resize, crop to multiple of 16)
		prepared, prepW, prepH := PrepareImage(img, limitPixels)
		fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)

		// Convert to tensor [-1, 1]
		tensor := ImageToTensor(prepared)

		// Encode with VAE - returns [1, L, 128]
		encoded := m.VAE.EncodeImage(tensor)
		squeezed := mlx.Squeeze(encoded, 0) // [L, C]

		// Defer eval - will be done with other setup arrays
		allTokens = append(allTokens, squeezed)
		fmt.Println("✓")
	}

	// For single image, just add batch dimension directly
	// For multiple images, concatenate first
	var tokens *mlx.Array
	if len(allTokens) == 1 {
		tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
	} else {
		tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
		tokens = mlx.ExpandDims(tokens, 0)     // [1, total_L, C]
	}

	return &ImageCondTokens{Tokens: tokens}, nil
}
@@ -1,224 +0,0 @@
//go:build mlx

package flux2

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// RoPEConfig holds 4D RoPE configuration for Flux2
type RoPEConfig struct {
	Theta    int32   // 2000 for Klein
	AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
}

// RoPECache holds precomputed RoPE cos/sin values
type RoPECache struct {
	Cos      *mlx.Array // [1, TotalSeqLen, 1, head_dim]
	Sin      *mlx.Array // [1, TotalSeqLen, 1, head_dim]
	TextLen  int32      // Length of text sequence
	ImageLen int32      // Length of image sequence
}

// PrepareTextIDs creates position IDs for text tokens.
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
	ids := make([]float32, seqLen*4)
	for i := int32(0); i < seqLen; i++ {
		idx := i * 4
		ids[idx+0] = 0          // T = 0
		ids[idx+1] = 0          // H = 0
		ids[idx+2] = 0          // W = 0
		ids[idx+3] = float32(i) // L = sequence position
	}
	return mlx.NewArray(ids, []int32{seqLen, 4})
}

// PrepareLatentIDs creates position IDs for image latent tokens.
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
// The latents are in row-major order (H then W).
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
	seqLen := height * width
	ids := make([]float32, seqLen*4)
	idx := 0
	for h := int32(0); h < height; h++ {
		for w := int32(0); w < width; w++ {
			ids[idx*4+0] = 0          // T = 0
			ids[idx*4+1] = float32(h) // H = row
			ids[idx*4+2] = float32(w) // W = column
			ids[idx*4+3] = 0          // L = 0
			idx++
		}
	}
	return mlx.NewArray(ids, []int32{seqLen, 4})
}
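
// For a 2x2 latent grid, the (T, H, W, L) rows are:
// (0,0,0,0), (0,0,1,0), (0,1,0,0), (0,1,1,0) - the same row-major token
// order produced by packLatents.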

// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
// Returns: [total_tokens, 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
	// Calculate total tokens
	totalTokens := int32(0)
	for i := range imageHeights {
		totalTokens += imageHeights[i] * imageWidths[i]
	}

	ids := make([]float32, totalTokens*4)
	idx := int32(0)
	for imgIdx, h := range imageHeights {
		w := imageWidths[imgIdx]
		tValue := float32(scale * int32(imgIdx+1))
		for hi := int32(0); hi < h; hi++ {
			for wi := int32(0); wi < w; wi++ {
				ids[idx*4+0] = tValue      // T = scale * (imgIdx + 1)
				ids[idx*4+1] = float32(hi) // H = row
				ids[idx*4+2] = float32(wi) // W = column
				ids[idx*4+3] = 0           // L = 0
				idx++
			}
		}
	}
	return mlx.NewArray(ids, []int32{totalTokens, 4})
}

// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
// ids: [L, 4] with (T, H, W, L) coordinates
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
// theta: base frequency (2000 for Klein)
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
	shape := ids.Shape()
	seqLen := shape[0]

	// Compute total head dim (sum of all axes dims)
	headDim := int32(0)
	for _, d := range axesDims {
		headDim += d
	}

	// Extract each coordinate dimension
	// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
	posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
	posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
	posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
	posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]

	// Compute frequencies for each axis
	logTheta := float32(math.Log(float64(theta)))
	cosArrs := make([]*mlx.Array, 4)
	sinArrs := make([]*mlx.Array, 4)
	positions := []*mlx.Array{posT, posH, posW, posL}

	for i, axisDim := range axesDims {
		half := axisDim / 2

		// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
		// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
		freqs := make([]float32, half)
		for j := int32(0); j < half; j++ {
			freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
		}
		freqArr := mlx.NewArray(freqs, []int32{1, half})

		// Compute pos * freq -> [L, half]
		posExpanded := positions[i]           // [L, 1]
		args := mlx.Mul(posExpanded, freqArr) // [L, half]

		// Compute cos and sin for this axis
		cosAxis := mlx.Cos(args) // [L, half]
		sinAxis := mlx.Sin(args) // [L, half]

		// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
		// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
		cosAxis = mlx.ExpandDims(cosAxis, 2)            // [L, half, 1]
		cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2})   // [L, half, 2]
		cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]

		sinAxis = mlx.ExpandDims(sinAxis, 2)
		sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
		sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)

		cosArrs[i] = cosAxis
		sinArrs[i] = sinAxis
	}

	// Concatenate all axes: [L, headDim]
	cos := mlx.Concatenate(cosArrs, 1)
	sin := mlx.Concatenate(sinArrs, 1)

	// Reshape to [1, L, 1, headDim] for broadcasting with attention
	cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
	sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)

	return cos, sin
}
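
// With axisDim = 32 and theta = 2000, frequencies run from freqs[0] = 1.0
// down to freqs[15] = 2000^(-30/32) ≈ 8.0e-4, so each axis mixes fast- and
// slow-rotating pairs exactly as in 1D RoPE.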

// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
// x: [B, L, nheads, head_dim]
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
// Returns: x with RoPE applied
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	nheads := shape[2]
	headDim := shape[3]
	half := headDim / 2

	// Reshape x to [B, L, nheads, half, 2] and split into real/imag
	xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)

	// Extract real (index 0) and imag (index 1) parts
	xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
	xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
	xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
	xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]

	// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
	// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
	negXImag := mlx.Neg(xImag)
	negXImag = mlx.ExpandDims(negXImag, 4)                        // [B, L, nheads, half, 1]
	xReal = mlx.ExpandDims(xReal, 4)                              // [B, L, nheads, half, 1]
	xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
	xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim)       // [B, L, nheads, headDim]

	// out = x * cos + x_rotated * sin
	return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
}
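
// Per even/odd pair (x0, x1) with angle a this computes the standard 2D
// rotation: out0 = x0*cos(a) - x1*sin(a), out1 = x1*cos(a) + x0*sin(a);
// the repeat_interleave in ComputeRoPE gives both slots the same angle.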

// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
// textLen: number of text tokens
// noiseH, noiseW: dimensions of the noise latent in patch tokens
// axesDims: [32, 32, 32, 32]
// theta: 2000
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
// scale: time coordinate offset between reference images (e.g., 10)
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
	textIDs := PrepareTextIDs(textLen)
	noiseIDs := PrepareLatentIDs(noiseH, noiseW)

	var allIDs *mlx.Array
	imageLen := noiseH * noiseW

	if len(refHeights) > 0 {
		refIDs := PrepareImageIDs(refHeights, refWidths, scale)
		allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
		for i := range refHeights {
			imageLen += refHeights[i] * refWidths[i]
		}
	} else {
		allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
	}

	cos, sin := ComputeRoPE(allIDs, axesDims, theta)
	cos = mlx.ToBFloat16(cos)
	sin = mlx.ToBFloat16(sin)

	return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
}
@@ -1,149 +0,0 @@
//go:build mlx

package flux2

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// SchedulerConfig holds Flow-Match scheduler configuration
type SchedulerConfig struct {
	NumTrainTimesteps  int32   `json:"num_train_timesteps"`  // 1000
	Shift              float32 `json:"shift"`                // 3.0 for Klein
	UseDynamicShifting bool    `json:"use_dynamic_shifting"` // true
	TimeShiftType      string  `json:"time_shift_type"`      // "exponential" or "linear"
}

// DefaultSchedulerConfig returns default config for Klein
func DefaultSchedulerConfig() *SchedulerConfig {
	return &SchedulerConfig{
		NumTrainTimesteps:  1000,
		Shift:              3.0, // Klein uses 3.0
		UseDynamicShifting: true,
		TimeShiftType:      "exponential",
	}
}

// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
type FlowMatchScheduler struct {
	Config    *SchedulerConfig
	Timesteps []float32 // Discretized timesteps (t from 1 to 0)
	Sigmas    []float32 // Noise levels at each timestep
	NumSteps  int       // Number of inference steps
}

// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
	return &FlowMatchScheduler{
		Config: cfg,
	}
}

// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
	s.SetTimestepsWithMu(numSteps, 0)
}

// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
	s.NumSteps = numSteps

	// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
	// Then applies time shift, appends 0.0 at end
	s.Sigmas = make([]float32, numSteps+1)

	for i := 0; i < numSteps; i++ {
		// linspace(1, 1/num_steps, num_steps)
		var sigma float32
		if numSteps == 1 {
			sigma = 1.0
		} else {
			sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
		}

		// Apply time shift if using dynamic shifting
		if s.Config.UseDynamicShifting && mu != 0 {
			sigma = s.timeShift(mu, sigma)
		} else {
			// If not dynamic shifting, apply fixed shift scaling like diffusers
			shift := s.Config.Shift
			sigma = shift * sigma / (1 + (shift-1)*sigma)
		}
		s.Sigmas[i] = sigma
	}
	// Append terminal zero
	s.Sigmas[numSteps] = 0.0

	// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
	s.Timesteps = make([]float32, numSteps+1)
	for i, v := range s.Sigmas {
		s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
	}
}
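
// Worked example, 4 steps with the fixed shift of 3.0: the raw linspace is
// [1.0, 0.75, 0.5, 0.25]; applying 3σ/(1+2σ) gives [1.0, 0.9, 0.75, 0.5],
// and the terminal zero makes Sigmas = [1.0, 0.9, 0.75, 0.5, 0.0].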

// timeShift applies the dynamic time shift
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
	if t <= 0 {
		return 0
	}
	if s.Config.TimeShiftType == "linear" {
		return mu / (mu + (1.0/t - 1.0))
	}
	// Default: exponential
	expMu := float32(math.Exp(float64(mu)))
	return expMu / (expMu + (1.0/t - 1.0))
}
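
// Equivalently, the exponential branch computes e^mu * t / (e^mu * t + (1 - t)):
// larger mu pushes every sigma toward 1, spending more of the schedule at high
// noise, which is why mu grows with image sequence length in CalculateShift.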

// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
	sigma := s.Sigmas[timestepIdx]
	sigmaNext := s.Sigmas[timestepIdx+1]

	// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
	dt := sigmaNext - sigma

	// Upcast to float32 for precision (matches diffusers)
	sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
	outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)

	scaledOutput := mlx.MulScalar(outputF32, dt)
	result := mlx.Add(sampleF32, scaledOutput)

	// Cast back to bfloat16
	return mlx.ToBFloat16(result)
}

// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
	if idx < len(s.Timesteps) {
		return s.Timesteps[idx]
	}
	return 0.0
}

// InitNoise creates initial noise for sampling
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
	return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}

// CalculateShift computes the mu shift value for dynamic scheduling
// Matches diffusers compute_empirical_mu function
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
	a1, b1 := float32(8.73809524e-05), float32(1.89833333)
	a2, b2 := float32(0.00016927), float32(0.45666666)

	seqLen := float32(imgSeqLen)

	if imgSeqLen > 4300 {
		return a2*seqLen + b2
	}

	m200 := a2*seqLen + b2
	m10 := a1*seqLen + b1

	a := (m200 - m10) / 190.0
	b := m200 - 200.0*a
	return a*float32(numSteps) + b
}
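
// The two lines m10 and m200 are empirical fits for 10-step and 200-step
// sampling; mu is interpolated linearly between them in the step count.
// E.g. imgSeqLen=4096, numSteps=4: m10 ≈ 2.256, m200 ≈ 1.150, mu ≈ 2.29.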
@@ -1,562 +0,0 @@
//go:build mlx

package flux2

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
	AttentionHeadDim         int32   `json:"attention_head_dim"`         // 128
	AxesDimsRoPE             []int32 `json:"axes_dims_rope"`             // [32, 32, 32, 32]
	Eps                      float32 `json:"eps"`                        // 1e-6
	GuidanceEmbeds           bool    `json:"guidance_embeds"`            // false for Klein
	InChannels               int32   `json:"in_channels"`                // 128
	JointAttentionDim        int32   `json:"joint_attention_dim"`        // 7680
	MLPRatio                 float32 `json:"mlp_ratio"`                  // 3.0
	NumAttentionHeads        int32   `json:"num_attention_heads"`        // 24
	NumLayers                int32   `json:"num_layers"`                 // 5
	NumSingleLayers          int32   `json:"num_single_layers"`          // 20
	PatchSize                int32   `json:"patch_size"`                 // 1
	RopeTheta                int32   `json:"rope_theta"`                 // 2000
	TimestepGuidanceChannels int32   `json:"timestep_guidance_channels"` // 256
}

// Computed dimensions
func (c *TransformerConfig) InnerDim() int32 {
	return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
}

func (c *TransformerConfig) MLPHiddenDim() int32 {
	return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
}

// TimestepEmbedder creates timestep embeddings
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
type TimestepEmbedder struct {
	Linear1  nn.LinearLayer `weight:"linear_1"`
	Linear2  nn.LinearLayer `weight:"linear_2"`
	EmbedDim int32          // 256
}

// Forward creates sinusoidal embeddings and projects them
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
	half := t.EmbedDim / 2
	freqs := make([]float32, half)
	for i := int32(0); i < half; i++ {
		freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
	}
	freqsArr := mlx.NewArray(freqs, []int32{1, half})

	// timesteps: [B] -> [B, 1]
	tExpanded := mlx.ExpandDims(timesteps, 1)
	// args: [B, half]
	args := mlx.Mul(tExpanded, freqsArr)

	// [cos(args), sin(args)] -> [B, embed_dim]
	sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)

	// MLP: linear_1 -> silu -> linear_2
	h := t.Linear1.Forward(sinEmbed)
	h = mlx.SiLU(h)
	return t.Linear2.Forward(h)
}
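
// With EmbedDim = 256 this is the classic sinusoidal embedding: 128
// frequencies decaying geometrically from 1.0 to 10000^(-127/128) ≈ 1.1e-4,
// giving the MLP both fast- and slow-varying features of the timestep.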

// TimeGuidanceEmbed wraps the timestep embedder
// Weight names: time_guidance_embed.timestep_embedder.*
type TimeGuidanceEmbed struct {
	TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
}

// Forward computes timestep embeddings
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
	return t.TimestepEmbedder.Forward(timesteps)
}

// Modulation computes adaptive modulation parameters
// Weight names: double_stream_modulation_img.linear.weight, etc.
type Modulation struct {
	Linear nn.LinearLayer `weight:"linear"`
}

// Forward computes modulation parameters
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
	h := mlx.SiLU(temb)
	return m.Linear.Forward(h)
}

// TransformerBlockAttn implements dual-stream attention
// Weight names: transformer_blocks.N.attn.*
type TransformerBlockAttn struct {
	// Image stream (separate Q, K, V projections)
	ToQ nn.LinearLayer `weight:"to_q"`
	ToK nn.LinearLayer `weight:"to_k"`
	ToV nn.LinearLayer `weight:"to_v"`
	// Note: to_out has .0 suffix in weights, handled specially
	ToOut0 nn.LinearLayer `weight:"to_out.0"`

	// Text stream (add_ projections)
	AddQProj nn.LinearLayer `weight:"add_q_proj"`
	AddKProj nn.LinearLayer `weight:"add_k_proj"`
	AddVProj nn.LinearLayer `weight:"add_v_proj"`
	ToAddOut nn.LinearLayer `weight:"to_add_out"`

	// QK norms for image stream
	NormQ *mlx.Array `weight:"norm_q.weight"`
	NormK *mlx.Array `weight:"norm_k.weight"`

	// QK norms for text stream (added)
	NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
	NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
}

// FeedForward implements SwiGLU MLP
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
type FeedForward struct {
	LinearIn  nn.LinearLayer `weight:"linear_in"`
	LinearOut nn.LinearLayer `weight:"linear_out"`
}

// Forward applies SwiGLU MLP
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
	// LinearIn outputs 2x hidden dim for SwiGLU
	h := ff.LinearIn.Forward(x)
	shape := h.Shape()
	half := shape[len(shape)-1] / 2

	// Split into gate and up
	gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
	up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})

	// SwiGLU: silu(gate) * up
	h = mlx.Mul(mlx.SiLU(gate), up)
	return ff.LinearOut.Forward(h)
}
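
// With the Klein dims noted above (inner dim 3072, MLP ratio 3.0), linear_in
// is 3072 -> 18432 (gate and up of 9216 each) and linear_out is 9216 -> 3072.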

// TransformerBlock implements a dual-stream transformer block
// Weight names: transformer_blocks.N.*
type TransformerBlock struct {
	Attn      *TransformerBlockAttn `weight:"attn"`
	FF        *FeedForward          `weight:"ff"`
	FFContext *FeedForward          `weight:"ff_context"`

	// Config (set after loading)
	NHeads  int32
	HeadDim int32
	Scale   float32
}

// Forward applies the dual-stream block
// imgHidden: [B, imgLen, dim]
// txtHidden: [B, txtLen, dim]
// imgMod, txtMod: modulation params [B, 6*dim] each
// cos, sin: RoPE values
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
	imgShape := imgHidden.Shape()
	B := imgShape[0]
	imgLen := imgShape[1]
	dim := imgShape[2]
	txtLen := txtHidden.Shape()[1]

	// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
	imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
	imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
	txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
	txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)

	// === Attention branch ===
	// Modulate inputs
	imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
	txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)

	// Compute Q, K, V for image stream (separate projections)
	imgQ := block.Attn.ToQ.Forward(imgNorm)
	imgK := block.Attn.ToK.Forward(imgNorm)
	imgV := block.Attn.ToV.Forward(imgNorm)

	// Compute Q, K, V for text stream (add_ projections)
	txtQ := block.Attn.AddQProj.Forward(txtNorm)
	txtK := block.Attn.AddKProj.Forward(txtNorm)
	txtV := block.Attn.AddVProj.Forward(txtNorm)

	// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
	imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
	imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
	imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
	txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
	txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
	txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)

	// Apply QK norm (RMSNorm with learned scale)
	imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
	imgK = applyQKNorm(imgK, block.Attn.NormK)
	txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
	txtK = applyQKNorm(txtK, block.Attn.NormAddedK)

	// Concatenate for joint attention: text first, then image
	q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
	k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
	v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)

	// Apply RoPE
	q = ApplyRoPE4D(q, cos, sin)
	k = ApplyRoPE4D(k, cos, sin)

	// Transpose for SDPA: [B, nheads, L, headDim]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Scaled dot-product attention
	out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)

	// Transpose back: [B, L, nheads, headDim]
	out = mlx.Transpose(out, 0, 2, 1, 3)

	// Split back into txt and img
	totalLen := txtLen + imgLen
	txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
	imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})

	// Reshape and project
	txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
	imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
	txtOut = block.Attn.ToAddOut.Forward(txtOut)
	imgOut = block.Attn.ToOut0.Forward(imgOut)

	// Apply gates and residual
	imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
	txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))

	// === MLP branch ===
	imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
	txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)

	imgFFOut := block.FF.Forward(imgNorm)
	txtFFOut := block.FFContext.Forward(txtNorm)

	imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
	txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))

	return imgHidden, txtHidden
}
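
// The text-first concatenation here must match the ID order built in
// PrepareRoPECache (text IDs, then latent IDs, then any reference IDs), so
// each token picks up the cos/sin row computed for its own coordinates.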

// SingleTransformerBlockAttn implements attention for single-stream blocks
// Weight names: single_transformer_blocks.N.attn.*
type SingleTransformerBlockAttn struct {
	ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
	ToOut        nn.LinearLayer `weight:"to_out"`          // Fused attn_out + MLP out
	NormQ        *mlx.Array     `weight:"norm_q.weight"`
	NormK        *mlx.Array     `weight:"norm_k.weight"`
}

// SingleTransformerBlock implements a single-stream transformer block
// Weight names: single_transformer_blocks.N.*
type SingleTransformerBlock struct {
	Attn *SingleTransformerBlockAttn `weight:"attn"`

	// Config
	NHeads    int32
	HeadDim   int32
	InnerDim  int32
	MLPHidDim int32
	Scale     float32
}

// Forward applies the single-stream block
// x: [B, L, dim] concatenated text+image
// mod: modulation [B, 3*dim]
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	dim := shape[2]

	// Parse modulation: (shift, scale, gate)
	shift, scale, gate := parseModulation3(mod, dim, 0)

	// Modulate input
	h := modulateLayerNorm(x, shift, scale)

	// Fused projection: QKV + MLP gate/up
	// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
	qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)

	// Split: first 3*dim is QKV, rest is MLP
	qkvDim := 3 * block.InnerDim
	qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
	mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})

	// Split QKV
	q, k, v := splitQKV(qkv, B, L, block.InnerDim)

	// Reshape for attention
	q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
	k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
	v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)

	// QK norm
	q = applyQKNorm(q, block.Attn.NormQ)
	k = applyQKNorm(k, block.Attn.NormK)

	// Apply RoPE
	q = ApplyRoPE4D(q, cos, sin)
	k = ApplyRoPE4D(k, cos, sin)

	// Transpose for SDPA
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// SDPA
	attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)

	// Transpose back and reshape
	attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
	attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)

	// MLP: SwiGLU
	mlpShape := mlpIn.Shape()
	half := mlpShape[2] / 2
	mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
	mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
	mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)

	// Concatenate attention and MLP for fused output
	combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)

	// Output projection
	out := block.Attn.ToOut.Forward(combined)

	// Apply gate and residual
	return mlx.Add(x, mlx.Mul(gate, out))
}
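
// With the Klein dims (inner dim 3072, MLP hidden 9216), to_qkv_mlp_proj
// outputs 3*3072 + 2*9216 = 27648 channels per token, and to_out maps the
// concatenated 3072 + 9216 = 12288 back down to 3072.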

// NormOut implements the output normalization with modulation
// Weight names: norm_out.linear.weight
type NormOut struct {
	Linear nn.LinearLayer `weight:"linear"`
}

// Forward computes final modulated output
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	dim := shape[2]

	// Modulation: temb -> silu -> linear -> [shift, scale]
	mod := mlx.SiLU(temb)
	mod = n.Linear.Forward(mod)

	// Split into scale and shift (diffusers order: scale first, shift second)
	scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
	shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
	shift = mlx.ExpandDims(shift, 1)
	scale = mlx.ExpandDims(scale, 1)

	// Modulate with LayerNorm
	return modulateLayerNorm(x, shift, scale)
}

// Flux2Transformer2DModel is the main Flux2 transformer
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
type Flux2Transformer2DModel struct {
	// Timestep embedding
	TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`

	// Shared modulation
	DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
	DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
	SingleStreamModulation    *Modulation `weight:"single_stream_modulation"`

	// Embedders
	XEmbedder       nn.LinearLayer `weight:"x_embedder"`
	ContextEmbedder nn.LinearLayer `weight:"context_embedder"`

	// Transformer blocks
	TransformerBlocks       []*TransformerBlock       `weight:"transformer_blocks"`
	SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`

	// Output
	NormOut *NormOut       `weight:"norm_out"`
	ProjOut nn.LinearLayer `weight:"proj_out"`

	*TransformerConfig
}

// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading transformer... ")

	// Load config from blob
	var cfg TransformerConfig
	if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.TransformerConfig = &cfg

	// Initialize slices
	m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
	m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)

	// Initialize TimeGuidanceEmbed with embed dim
	m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
		TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
	}

	// Load weights from tensor blobs
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}
	m.initComputedFields()
	fmt.Println("✓")
	return nil
}

// initComputedFields initializes computed fields after loading weights
func (m *Flux2Transformer2DModel) initComputedFields() {
	cfg := m.TransformerConfig
	innerDim := cfg.InnerDim()
	scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))

	// Initialize transformer blocks
	for _, block := range m.TransformerBlocks {
		block.NHeads = cfg.NumAttentionHeads
		block.HeadDim = cfg.AttentionHeadDim
		block.Scale = scale
	}

	// Initialize single transformer blocks
	for _, block := range m.SingleTransformerBlocks {
		block.NHeads = cfg.NumAttentionHeads
		block.HeadDim = cfg.AttentionHeadDim
		block.InnerDim = innerDim
		block.MLPHidDim = cfg.MLPHiddenDim()
		block.Scale = scale
	}
}

// Forward runs the Flux2 transformer
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
	patchShape := patches.Shape()
	B := patchShape[0]
	imgLen := patchShape[1]
	txtLen := txtEmbeds.Shape()[1]

	// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
	scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)

	// Compute timestep embedding
	temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)

	// Embed patches and text
	imgHidden := m.XEmbedder.Forward(patches)
	txtHidden := m.ContextEmbedder.Forward(txtEmbeds)

	// Compute shared modulation
	imgMod := m.DoubleStreamModulationImg.Forward(temb)
	txtMod := m.DoubleStreamModulationTxt.Forward(temb)
	singleMod := m.SingleStreamModulation.Forward(temb)

	// Double (dual-stream) blocks
	for _, block := range m.TransformerBlocks {
		imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
	}

	// Concatenate for single-stream: text first, then image
	hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)

	// Single-stream blocks
	for _, block := range m.SingleTransformerBlocks {
		hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
	}

	// Extract image portion
	totalLen := txtLen + imgLen
	imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})

	// Final norm and projection
	imgOut = m.NormOut.Forward(imgOut, temb)
	return m.ProjOut.Forward(imgOut)
}
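
// Note the modulation triples are computed once per forward pass and shared
// by every block of each kind; only per-block weights differ. Since the
// scheduler adds the output directly to the packed noise, ProjOut maps the
// 3072-dim stream back to the 128 packed-latent channels (in_channels).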

// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
// See applyQKNorm function below

// compiledSwiGLU fuses: silu(gate) * up
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
var compiledSwiGLU *mlx.CompiledFunc

func getCompiledSwiGLU() *mlx.CompiledFunc {
	if compiledSwiGLU == nil {
		compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
			gate, up := inputs[0], inputs[1]
			return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
		}, true)
	}
	return compiledSwiGLU
}

// Helper functions

// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
	B := mod.Shape()[0]
	start := offset * dim
	shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
	scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
	gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})

	// Expand for broadcasting [B, dim] -> [B, 1, dim]
	shift = mlx.ExpandDims(shift, 1)
	scale = mlx.ExpandDims(scale, 1)
	gate = mlx.ExpandDims(gate, 1)

	return shift, scale, gate
}

// modulateLayerNorm applies LayerNorm then shift/scale modulation
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
	// Fast LayerNorm without learnable params
	x = mlx.LayerNorm(x, 1e-6)

	// Modulate: x * (1 + scale) + shift
	x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
	return mlx.Add(x, shift)
}
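
// This is the usual adaLN form: out = LN(x) * (1 + scale) + shift, so a
// zero scale/shift reduces to a plain LayerNorm.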

// splitQKV splits a fused QKV tensor into Q, K, V
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
	q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
	k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
	v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
	return q, k, v
}

// applyQKNorm applies RMSNorm with learned scale (no bias)
// Uses the optimized mlx_fast_rms_norm
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
	return mlx.RMSNorm(x, scale, 1e-6)
}
@@ -1,804 +0,0 @@
//go:build mlx

package flux2

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/vae"
)

// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
	ActFn             string  `json:"act_fn"`                  // "silu"
	BatchNormEps      float32 `json:"batch_norm_eps"`          // 0.0001
	BatchNormMomentum float32 `json:"batch_norm_momentum"`     // 0.1
	BlockOutChannels  []int32 `json:"block_out_channels"`      // [128, 256, 512, 512]
	ForceUpcast       bool    `json:"force_upcast"`            // true
	InChannels        int32   `json:"in_channels"`             // 3
	LatentChannels    int32   `json:"latent_channels"`         // 32
	LayersPerBlock    int32   `json:"layers_per_block"`        // 2
	MidBlockAddAttn   bool    `json:"mid_block_add_attention"` // true
	NormNumGroups     int32   `json:"norm_num_groups"`         // 32
	OutChannels       int32   `json:"out_channels"`            // 3
	PatchSize         []int32 `json:"patch_size"`              // [2, 2]
	SampleSize        int32   `json:"sample_size"`             // 1024
	UsePostQuantConv  bool    `json:"use_post_quant_conv"`     // true
	UseQuantConv      bool    `json:"use_quant_conv"`          // true
}

// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
	RunningMean *mlx.Array // [C]
	RunningVar  *mlx.Array // [C]
	Weight      *mlx.Array // [C] gamma
	Bias        *mlx.Array // [C] beta
	Eps         float32
	Momentum    float32
}

// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[3]

	// Reshape stats for broadcasting [1, 1, 1, C]
	mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
	variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)

	// Normalize: (x - mean) / sqrt(var + eps)
	xNorm := mlx.Sub(x, mean)
	xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))

	// Scale and shift (only if affine=True)
	if bn.Weight != nil {
		weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if bn.Bias != nil {
		bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	C := shape[3]

	// Reshape stats for broadcasting [1, 1, 1, C]
	mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
	variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)

	// Inverse: first undo affine, then undo normalization
	// For affine=False: x_denorm = x * sqrt(var + eps) + mean
	if bn.Bias != nil {
		bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
		x = mlx.Sub(x, bias)
	}
	if bn.Weight != nil {
		weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
		x = mlx.Div(x, weight)
	}
	x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
	x = mlx.Add(x, mean)

	return x
}
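
// Denormalize is the exact inverse of Forward above:
// x = ((y - beta) / gamma) * sqrt(var + eps) + mean, applied in reverse order.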

// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
	Weight    *mlx.Array `weight:"weight"`
	Bias      *mlx.Array `weight:"bias"`
	NumGroups int32
	Eps       float32
}

// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// Reshape to [B, H, W, groups, C/groups]
	groupSize := C / gn.NumGroups
	x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)

	// Compute mean and variance per group
	mean := mlx.Mean(x, 1, true)
	mean = mlx.Mean(mean, 2, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)

	sq := mlx.Square(xCentered)
	variance := mlx.Mean(sq, 1, true)
	variance = mlx.Mean(variance, 2, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back to [B, H, W, C]
	xNorm = mlx.Reshape(xNorm, B, H, W, C)

	// Scale and shift
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
	Weight  *mlx.Array `weight:"weight"`
	Bias    *mlx.Array `weight:"bias,optional"`
	Stride  int32
	Padding int32
}

// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
	if field == "Weight" {
		return mlx.Transpose(arr, 0, 2, 3, 1)
	}
	return arr
}

// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
	out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)

	if conv.Bias != nil {
		bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
		out = mlx.Add(out, bias)
	}

	return out
}
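
// E.g. a PyTorch conv weight of shape [512, 256, 3, 3] (out, in, kH, kW)
// becomes [512, 3, 3, 256] (out, kH, kW, in) to match MLX's NHWC convolution.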

// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
	Norm1        *GroupNormLayer `weight:"norm1"`
	Conv1        *Conv2D         `weight:"conv1"`
	Norm2        *GroupNormLayer `weight:"norm2"`
	Conv2        *Conv2D         `weight:"conv2"`
	ConvShortcut *Conv2D         `weight:"conv_shortcut,optional"`
}

// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
	h := rb.Norm1.Forward(x)
	h = mlx.SiLU(h)
	h = rb.Conv1.Forward(h)

	h = rb.Norm2.Forward(h)
	h = mlx.SiLU(h)
	h = rb.Conv2.Forward(h)

	if rb.ConvShortcut != nil {
		x = rb.ConvShortcut.Forward(x)
	}

	return mlx.Add(h, x)
}

// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
	GroupNorm *GroupNormLayer `weight:"group_norm"`
	ToQ       nn.LinearLayer  `weight:"to_q"`
	ToK       nn.LinearLayer  `weight:"to_k"`
	ToV       nn.LinearLayer  `weight:"to_v"`
	ToOut     nn.LinearLayer  `weight:"to_out.0"`
}

// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	h := ab.GroupNorm.Forward(x)
	h = mlx.Reshape(h, B, H*W, C)

	q := ab.ToQ.Forward(h)
	k := ab.ToK.Forward(h)
	v := ab.ToV.Forward(h)

	q = mlx.ExpandDims(q, 1)
	k = mlx.ExpandDims(k, 1)
	v = mlx.ExpandDims(v, 1)

	scale := float32(1.0 / math.Sqrt(float64(C)))
	out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
	out = mlx.Squeeze(out, 1)

	out = ab.ToOut.Forward(out)
	out = mlx.Reshape(out, B, H, W, C)
	out = mlx.Add(out, residual)

	return out
}

// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Upsample     *Conv2D
}

// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.ResnetBlocks {
		x = resnet.Forward(x)
	}

	if ub.Upsample != nil {
		x = upsample2x(x)
		x = ub.Upsample.Forward(x)
	}

	return x
}

// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	H := shape[1]
	W := shape[2]

	hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
	hIdx = mlx.Reshape(hIdx, H, 1)
	hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
	hIdx = mlx.Reshape(hIdx, H*2)

	wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
	wIdx = mlx.Reshape(wIdx, W, 1)
	wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
	wIdx = mlx.Reshape(wIdx, W*2)

	x = mlx.Take(x, hIdx, 1)
	x = mlx.Take(x, wIdx, 2)

	return x
}
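
// The index trick builds [0, 0, 1, 1, ..., H-1, H-1] and gathers along each
// spatial axis, so every pixel is duplicated 2x2: [1, 4, 4, C] -> [1, 8, 8, C].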

// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
	Resnet1   *ResnetBlock2D
	Attention *VAEAttentionBlock
	Resnet2   *ResnetBlock2D
}

// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	x = mb.Resnet1.Forward(x)
	x = mb.Attention.Forward(x)
	x = mb.Resnet2.Forward(x)
	return x
}

// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
	return vae.DefaultTilingConfig()
}
|
||||
|
||||
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
|
||||
type AutoencoderKLFlux2 struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Encoder components (for image editing)
|
||||
EncoderConvIn *Conv2D
|
||||
EncoderMid *VAEMidBlock
|
||||
EncoderDown []*DownEncoderBlock2D
|
||||
EncoderNormOut *GroupNormLayer
|
||||
EncoderConvOut *Conv2D
|
||||
|
||||
// Decoder components
|
||||
DecoderConvIn *Conv2D
|
||||
DecoderMid *VAEMidBlock
|
||||
DecoderUp []*UpDecoderBlock2D
|
||||
DecoderNormOut *GroupNormLayer
|
||||
DecoderConvOut *Conv2D
|
||||
|
||||
// Quant conv layers
|
||||
QuantConv *Conv2D
|
||||
PostQuantConv *Conv2D
|
||||
|
||||
// BatchNorm for latent normalization
|
||||
LatentBN *BatchNorm2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// DownEncoderBlock2D implements a downsampling encoder block
|
||||
type DownEncoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Downsample *Conv2D
|
||||
}
|
||||
|
||||
// Forward applies the down encoder block
|
||||
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range db.ResnetBlocks {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
|
||||
if db.Downsample != nil {
|
||||
// Pad then conv with stride 2
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
|
||||
x = db.Downsample.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
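
// Illustrative sketch (assumes the 3x3 downsample kernel diffusers uses): the
// one-sided Pad above adds a single row/column on the bottom and right only,
// so the stride-2 conv with no padding halves each spatial dimension.
func downsampledSize(in int) int {
    padded := in + 1        // asymmetric pad: (0,1) on H and W
    return (padded-3)/2 + 1 // conv output size for k=3, s=2, p=0
}

// downsampledSize(64) == 32, downsampledSize(65) == 32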

// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
    fmt.Print(" Loading VAE... ")

    // Load config from blob
    var cfg VAEConfig
    if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Config = &cfg

    // Load weights from tensor blobs
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    return m.loadWeights(weights, &cfg)
}

// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
    var err error

    // Load encoder components (for image conditioning)
    if err := m.loadEncoderWeights(weights, cfg); err != nil {
        return fmt.Errorf("encoder: %w", err)
    }

    // Load decoder conv_in
    m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
    if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
        return fmt.Errorf("decoder.conv_in: %w", err)
    }

    // Load mid block
    m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
    if err != nil {
        return fmt.Errorf("decoder.mid_block: %w", err)
    }

    // Load up blocks
    numBlocks := len(cfg.BlockOutChannels)
    m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
    for i := 0; i < numBlocks; i++ {
        prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
        hasUpsample := i < numBlocks-1
        m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
        if err != nil {
            return fmt.Errorf("%s: %w", prefix, err)
        }
    }

    // Load decoder conv_norm_out and conv_out
    m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
    if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
        return fmt.Errorf("decoder.conv_norm_out: %w", err)
    }

    m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
    if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
        return fmt.Errorf("decoder.conv_out: %w", err)
    }

    // Load post_quant_conv
    if cfg.UsePostQuantConv {
        m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
        if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
            return fmt.Errorf("post_quant_conv: %w", err)
        }
    }

    // Load latent BatchNorm (affine=False, so no weight/bias)
    bnMean, err := weights.GetTensor("bn.running_mean")
    if err != nil {
        return fmt.Errorf("bn.running_mean: %w", err)
    }
    bnVar, err := weights.GetTensor("bn.running_var")
    if err != nil {
        return fmt.Errorf("bn.running_var: %w", err)
    }
    m.LatentBN = &BatchNorm2D{
        RunningMean: bnMean,
        RunningVar:  bnVar,
        Weight:      nil, // affine=False
        Bias:        nil, // affine=False
        Eps:         cfg.BatchNormEps,
        Momentum:    cfg.BatchNormMomentum,
    }

    fmt.Println("✓")
    return nil
}

// loadVAEMidBlock loads the mid block.
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
    resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
    if err != nil {
        return nil, err
    }

    attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
    if err != nil {
        return nil, err
    }

    resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
    if err != nil {
        return nil, err
    }

    return &VAEMidBlock{
        Resnet1:   resnet1,
        Attention: attention,
        Resnet2:   resnet2,
    }, nil
}

// loadResnetBlock2D loads a ResNet block.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
    block := &ResnetBlock2D{
        Norm1:        &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
        Conv1:        &Conv2D{Stride: 1, Padding: 1},
        Norm2:        &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
        Conv2:        &Conv2D{Stride: 1, Padding: 1},
        ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
    }
    if err := safetensors.LoadModule(block, weights, prefix); err != nil {
        return nil, err
    }
    // If ConvShortcut wasn't loaded (no weights found), nil it out
    if block.ConvShortcut.Weight == nil {
        block.ConvShortcut = nil
    }
    return block, nil
}

// loadVAEAttentionBlock loads an attention block using LoadModule.
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
    ab := &VAEAttentionBlock{
        GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
    }
    if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
        return nil, err
    }
    return ab, nil
}

// loadUpDecoderBlock2D loads an up decoder block.
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
    resnets := make([]*ResnetBlock2D, numLayers)
    for i := int32(0); i < numLayers; i++ {
        resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
        resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
        if err != nil {
            return nil, err
        }
        resnets[i] = resnet
    }

    var upsample *Conv2D
    if hasUpsample {
        upsample = &Conv2D{Stride: 1, Padding: 1}
        if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
            return nil, err
        }
    }

    return &UpDecoderBlock2D{
        ResnetBlocks: resnets,
        Upsample:     upsample,
    }, nil
}

// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches.
// It prepares latents for the transformer; Unpatchify reverses it.
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
    shape := latents.Shape()
    B := shape[0]
    C := shape[1]
    H := shape[2]
    W := shape[3]

    patchH := vae.Config.PatchSize[0]
    patchW := vae.Config.PatchSize[1]

    pH := H / patchH
    pW := W / patchW

    // [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
    x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
    // [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
    x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
    // [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
    return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
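
// Shape sketch: Patchify on concrete numbers, assuming the 2x2 patch size and
// 32 latent channels used elsewhere in this file.
func patchifiedShape(B, C, H, W int32) (int32, int32, int32) {
    L := (H / 2) * (W / 2) // one token per 2x2 latent patch
    return B, L, C * 4     // (1, 32, 64, 64) -> (1, 1024, 128)
}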

// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
    shape := patches.Shape()
    B := shape[0]

    patchH := vae.Config.PatchSize[0]
    patchW := vae.Config.PatchSize[1]

    // [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
    x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
    // [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
    x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
    // [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
    H := pH * patchH
    W := pW * patchW
    return mlx.Reshape(x, B, C, H, W)
}

// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
    shape := x.Shape()
    C := shape[2] // 128

    // Reshape stats for broadcasting [1, 1, C]
    mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
    variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)

    // Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
    if vae.LatentBN.Bias != nil {
        bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
        x = mlx.Sub(x, bias)
    }
    if vae.LatentBN.Weight != nil {
        weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
        x = mlx.Div(x, weight)
    }
    x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
    x = mlx.Add(x, mean)

    return x
}
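
// Hedged check (sketch only): denormalizePatchified is the exact inverse of
// normalizePatchified further below. In scalar form, with affine=False:
//
//	normalize:   y = (x - mean) / sqrt(var + eps)
//	denormalize: x = y * sqrt(var + eps) + mean
//
// so a normalize/denormalize round trip returns x up to float rounding:
func bnRoundTrip(x, mean, variance, eps float64) float64 {
    y := (x - mean) / math.Sqrt(variance+eps)
    return y*math.Sqrt(variance+eps) + mean // == x (within rounding)
}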

// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
    // Denormalize patchified latents
    z := v.denormalizePatchified(latents)

    // Unpatchify: [B, L, C*4] -> [B, C, H, W]
    z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)

    // Convert NCHW -> NHWC for processing
    z = mlx.Transpose(z, 0, 2, 3, 1)

    // Use tiled decoding if enabled
    if v.Tiling != nil {
        mlx.Eval(z)
        return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
    }

    // Direct decode (no tiling)
    h := v.decodeTile(z)
    h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
    h = mlx.Transpose(h, 0, 3, 1, 2)
    return h
}
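
// Usage sketch (hypothetical caller): tiling is opt-in and trades a little
// compute at tile seams for a much lower peak memory on large images.
//
//	m.Tiling = DefaultTilingConfig() // nil keeps single-pass decoding
//	img := m.Decode(latents, pH, pW) // [B, 3, H, W] in [0, 1]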

// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
    // Post-quant conv
    if vae.PostQuantConv != nil {
        z = vae.PostQuantConv.Forward(z)
    }

    // Decoder
    h := vae.DecoderConvIn.Forward(z)
    h = vae.DecoderMid.Forward(h)

    for _, upBlock := range vae.DecoderUp {
        h = upBlock.Forward(h)
    }

    h = vae.DecoderNormOut.Forward(h)
    h = mlx.SiLU(h)
    h = vae.DecoderConvOut.Forward(h)

    // VAE outputs [-1, 1], convert to [0, 1]
    h = mlx.MulScalar(h, 0.5)
    h = mlx.AddScalar(h, 0.5)

    return h
}

// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
    var err error

    // Load encoder conv_in
    m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
    if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
        return fmt.Errorf("encoder.conv_in: %w", err)
    }

    // Load encoder down blocks
    numBlocks := len(cfg.BlockOutChannels)
    m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
    for i := 0; i < numBlocks; i++ {
        prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
        hasDownsample := i < numBlocks-1
        m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
        if err != nil {
            return fmt.Errorf("%s: %w", prefix, err)
        }
    }

    // Load encoder mid block
    m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
    if err != nil {
        return fmt.Errorf("encoder.mid_block: %w", err)
    }

    // Load encoder conv_norm_out and conv_out
    m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
    if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
        return fmt.Errorf("encoder.conv_norm_out: %w", err)
    }

    m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
    if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
        return fmt.Errorf("encoder.conv_out: %w", err)
    }

    // Load quant_conv (for encoding)
    if cfg.UseQuantConv {
        m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
        if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
            return fmt.Errorf("quant_conv: %w", err)
        }
    }

    return nil
}

// loadDownEncoderBlock2D loads a down encoder block.
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
    resnets := make([]*ResnetBlock2D, numLayers)
    for i := int32(0); i < numLayers; i++ {
        resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
        resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
        if err != nil {
            return nil, err
        }
        resnets[i] = resnet
    }

    var downsample *Conv2D
    if hasDownsample {
        downsample = &Conv2D{Stride: 2, Padding: 0}
        if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
            return nil, err
        }
    }

    return &DownEncoderBlock2D{
        ResnetBlocks: resnets,
        Downsample:   downsample,
    }, nil
}

// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
    // Convert NCHW -> NHWC
    x := mlx.Transpose(image, 0, 2, 3, 1)

    // Encoder
    h := vae.EncoderConvIn.Forward(x)

    for _, downBlock := range vae.EncoderDown {
        h = downBlock.Forward(h)
    }

    h = vae.EncoderMid.Forward(h)
    h = vae.EncoderNormOut.Forward(h)
    h = mlx.SiLU(h)
    h = vae.EncoderConvOut.Forward(h)

    // Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
    if vae.QuantConv != nil {
        h = vae.QuantConv.Forward(h)
    }

    // Take only the mean (first latent_channels) - deterministic encoding
    // h is [B, H, W, 64] -> take first 32 channels for mean
    shape := h.Shape()
    latentChannels := vae.Config.LatentChannels // 32
    h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})

    // Convert NHWC -> NCHW for patchifying
    h = mlx.Transpose(h, 0, 3, 1, 2)

    // Patchify: [B, C, H, W] -> [B, L, C*4]
    h = vae.Patchify(h)

    // Apply BatchNorm on patchified latents [B, L, 128]
    // The BatchNorm has 128 channels matching the patchified dimension
    h = vae.normalizePatchified(h)

    return h
}
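
// Usage sketch (hypothetical caller): for image editing, EncodeImage and
// Decode bracket the transformer. With an HxW input, the patch grid is
// pH = H/16, pW = W/16 (8x VAE downsample times the 2x2 patchify):
//
//	lat := v.EncodeImage(img)    // [B, L, 128] normalized latents
//	out := v.Decode(lat, pH, pW) // [B, 3, H, W] reconstruction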

// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
    shape := x.Shape()
    C := shape[2] // 128

    // Reshape stats for broadcasting [1, 1, C]
    mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
    variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)

    // Normalize: (x - mean) / sqrt(var + eps)
    xNorm := mlx.Sub(x, mean)
    xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))

    // Scale and shift (only if affine=True)
    if vae.LatentBN.Weight != nil {
        weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
        xNorm = mlx.Mul(xNorm, weight)
    }
    if vae.LatentBN.Bias != nil {
        bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
        xNorm = mlx.Add(xNorm, bias)
    }

    return xNorm
}

@@ -1,390 +0,0 @@

//go:build mlx

// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3

import (
    "fmt"
    "math"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/nn"
    "github.com/ollama/ollama/x/imagegen/safetensors"
    "github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Config holds Qwen3 text encoder configuration
type Config struct {
    HiddenSize        int32   `json:"hidden_size"`
    NumHiddenLayers   int32   `json:"num_hidden_layers"`
    IntermediateSize  int32   `json:"intermediate_size"`
    NumAttentionHeads int32   `json:"num_attention_heads"`
    NumKeyValueHeads  int32   `json:"num_key_value_heads"`
    VocabSize         int32   `json:"vocab_size"`
    RMSNormEps        float32 `json:"rms_norm_eps"`
    RopeTheta         float32 `json:"rope_theta"`
    HeadDim           int32   `json:"head_dim"`
}

// Attention implements Qwen3 attention with QK norms
type Attention struct {
    QProj nn.LinearLayer `weight:"q_proj"`
    KProj nn.LinearLayer `weight:"k_proj"`
    VProj nn.LinearLayer `weight:"v_proj"`
    OProj nn.LinearLayer `weight:"o_proj"`
    QNorm *nn.RMSNorm    `weight:"q_norm"`
    KNorm *nn.RMSNorm    `weight:"k_norm"`
    // Computed fields
    NHeads    int32
    NKVHeads  int32
    HeadDim   int32
    Scale     float32
    RopeTheta float32
}

// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    L := shape[1]
    H := shape[2]
    D := shape[3]
    half := D / 2

    freqsArr := make([]float32, half)
    logTheta := float32(math.Log(float64(theta)))
    for i := int32(0); i < half; i++ {
        freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
    }
    freqs := mlx.NewArray(freqsArr, []int32{half})

    posArr := make([]float32, seqLen)
    for i := int32(0); i < seqLen; i++ {
        posArr[i] = float32(i)
    }
    pos := mlx.NewArray(posArr, []int32{seqLen})

    posExpanded := mlx.Reshape(pos, seqLen, 1)
    freqsExpanded := mlx.Reshape(freqs, 1, half)
    args := mlx.Mul(posExpanded, freqsExpanded)

    cosVals := mlx.Cos(args)
    sinVals := mlx.Sin(args)
    cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
    sinVals = mlx.Reshape(sinVals, seqLen, 1, half)

    x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
    x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})

    part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
    part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))

    return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
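
// Closed-form sketch of the schedule above: pair i of each head rotates by
// angle pos * theta^(-i/half), the standard RoPE frequency ladder, computed
// here without mlx for reference.
func ropeAngle(pos, i, half int32, theta float64) float64 {
    return float64(pos) * math.Pow(theta, -float64(i)/float64(half))
}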

// Forward computes attention with causal masking and optional padding mask
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    L := shape[1]

    q := attn.QProj.Forward(x)
    k := attn.KProj.Forward(x)
    v := attn.VProj.Forward(x)

    q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
    k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
    v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

    // QK norm uses a hardcoded eps of 1e-6 (Qwen3-specific)
    q = attn.QNorm.Forward(q, 1e-6)
    k = attn.KNorm.Forward(k, 1e-6)

    q = applyRoPEQwen3(q, L, attn.RopeTheta)
    k = applyRoPEQwen3(k, L, attn.RopeTheta)

    q = mlx.Transpose(q, 0, 2, 1, 3)
    k = mlx.Transpose(k, 0, 2, 1, 3)
    v = mlx.Transpose(v, 0, 2, 1, 3)

    if attn.NKVHeads < attn.NHeads {
        repeats := attn.NHeads / attn.NKVHeads
        k = repeatKV(k, repeats)
        v = repeatKV(v, repeats)
    }

    out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)

    out = mlx.Transpose(out, 0, 2, 1, 3)
    out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

    out = attn.OProj.Forward(out)

    return out
}

// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
    if repeats == 1 {
        return x
    }
    shape := x.Shape()
    x = mlx.ExpandDims(x, 2)
    x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
    return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
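
// Shape sketch: with NHeads=16 and NKVHeads=8 (illustrative numbers), repeatKV
// turns k/v of shape [B, 8, L, D] into [B, 16, L, D] by tiling each KV head
// twice, so grouped-query attention gives every query head a matching
// key/value head.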

// MLP implements Qwen3 SwiGLU MLP
type MLP struct {
    GateProj nn.LinearLayer `weight:"gate_proj"`
    UpProj   nn.LinearLayer `weight:"up_proj"`
    DownProj nn.LinearLayer `weight:"down_proj"`
}

// Forward applies the MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
    gate := m.GateProj.Forward(x)
    gate = mlx.SiLU(gate)
    up := m.UpProj.Forward(x)
    h := mlx.Mul(gate, up)
    return m.DownProj.Forward(h)
}

// Block represents a single Qwen3 transformer block
type Block struct {
    Attention         *Attention  `weight:"self_attn"`
    MLP               *MLP        `weight:"mlp"`
    InputLayerNorm    *nn.RMSNorm `weight:"input_layernorm"`
    PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}

// Forward applies the Qwen3 block
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
    h := qb.InputLayerNorm.Forward(x, eps)
    attnOut := qb.Attention.Forward(h, mask, maskMode)
    x = mlx.Add(x, attnOut)

    h = qb.PostAttnLayerNorm.Forward(x, eps)
    mlpOut := qb.MLP.Forward(h)
    x = mlx.Add(x, mlpOut)

    return x
}

// TextEncoder is the full Qwen3 encoder
type TextEncoder struct {
    EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
    Layers      []*Block      `weight:"model.layers"`
    FinalNorm   *nn.RMSNorm   `weight:"model.norm"`
    *Config
}

// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
    fmt.Print(" Loading text encoder... ")

    // Load config from blob
    var cfg Config
    if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Config = &cfg
    m.Layers = make([]*Block, cfg.NumHiddenLayers)

    // Load weights from tensor blobs
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }
    m.initComputedFields()
    fmt.Println("✓")
    return nil
}

// initComputedFields initializes computed fields after loading weights
func (m *TextEncoder) initComputedFields() {
    cfg := m.Config
    m.FinalNorm.Eps = cfg.RMSNormEps
    for _, block := range m.Layers {
        // Attention
        block.Attention.NHeads = cfg.NumAttentionHeads
        block.Attention.NKVHeads = cfg.NumKeyValueHeads
        block.Attention.HeadDim = cfg.HeadDim
        block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
        block.Attention.RopeTheta = cfg.RopeTheta
        block.Attention.QNorm.Eps = cfg.RMSNormEps
        block.Attention.KNorm.Eps = cfg.RMSNormEps
        // Block norms
        block.InputLayerNorm.Eps = cfg.RMSNormEps
        block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
    }
}

// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
    h := te.EmbedTokens.Forward(tokens)
    eps := te.RMSNormEps

    for _, layer := range te.Layers {
        h = layer.Forward(h, eps, attnMask, maskMode)
    }

    // Apply final RMS norm
    h = te.FinalNorm.Forward(h, eps)

    return h
}

// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
// This is used by Flux2 which needs embeddings from specific intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
    h := te.EmbedTokens.Forward(tokens)
    eps := te.RMSNormEps

    outputs := make([]*mlx.Array, len(layerIndices))
    layerSet := make(map[int]int)
    for i, idx := range layerIndices {
        layerSet[idx] = i
    }

    for i, layer := range te.Layers {
        h = layer.Forward(h, eps, attnMask, maskMode)
        if outIdx, ok := layerSet[i]; ok {
            outputs[outIdx] = h
        }
    }

    return outputs
}

// ApplyChatTemplate wraps prompt in Qwen3 chat format.
// If think is true, adds the <think></think> block after the assistant tag
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
func ApplyChatTemplate(prompt string, think bool) string {
    base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
    if think {
        return base + "<think>\n\n</think>\n\n"
    }
    return base
}
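
// Usage sketch: ApplyChatTemplate("a cat", false) yields
//
//	<|im_start|>user
//	a cat<|im_end|>
//	<|im_start|>assistant
//
// and think=true appends "<think>\n\n</think>\n\n", so the encoder sees the
// same token stream the Python tokenizer produces with enable_thinking=False.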

// EncodePrompt encodes a text prompt using the tokenizer and encoder.
// If think is true, includes the <think></think> block in the chat template.
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
    formattedPrompt := ApplyChatTemplate(prompt, think)

    tokens := tok.Encode(formattedPrompt, false)

    if len(tokens) > maxLen {
        tokens = tokens[:maxLen]
    }

    maskData := make([]float32, maxLen)
    for i := 0; i < len(tokens); i++ {
        maskData[i] = 1.0
    }

    // Get PAD token (different from EOS for Qwen3)
    padToken := tok.PAD()
    if padToken < 0 {
        padToken = tok.EOS() // fallback
    }

    paddedTokens := make([]int32, maxLen)
    copy(paddedTokens, tokens)
    for i := len(tokens); i < maxLen; i++ {
        paddedTokens[i] = padToken
    }

    tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
    maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})

    // Build combined causal + PAD mask [L, L]
    // mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
    L := int32(maxLen)
    validLen := int32(len(tokens))
    combinedMaskData := make([]float32, L*L)
    negInf := float32(-1e9)
    for i := int32(0); i < L; i++ {
        for j := int32(0); j < L; j++ {
            idx := i*L + j
            if j <= i && j < validLen {
                combinedMaskData[idx] = 0
            } else {
                combinedMaskData[idx] = negInf
            }
        }
    }
    maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})

    embeddings := te.Forward(tokensArr, maskMat, "")

    return embeddings, maskArr
}
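
// Illustrative sketch: the combined mask for maxLen=4 with 2 valid tokens.
// Rows are query positions, columns are key positions; 0 means attend,
// -inf (here -1e9) means blocked:
//
//	[    0  -inf  -inf  -inf ]
//	[    0     0  -inf  -inf ]
//	[    0     0  -inf  -inf ]  // PAD keys stay blocked for later queries
//	[    0     0  -inf  -inf ]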

// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
// If think is true, includes the <think></think> block in the chat template.
// Returns embeddings and padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
    formattedPrompt := ApplyChatTemplate(prompt, think)
    tokens := tok.Encode(formattedPrompt, false)

    if len(tokens) > maxLen {
        tokens = tokens[:maxLen]
    }

    // Pad to maxLen
    padToken := tok.PAD()
    if padToken < 0 {
        padToken = tok.EOS() // fallback
    }
    padded := make([]int32, maxLen)
    copy(padded, tokens)
    for i := len(tokens); i < maxLen; i++ {
        padded[i] = padToken
    }
    tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})

    // Build combined causal + PAD mask [L, L]
    // mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
    // This combines causal masking with PAD token masking
    L := int32(maxLen)
    validLen := int32(len(tokens))
    maskData := make([]float32, L*L)
    negInf := float32(-1e9)
    for i := int32(0); i < L; i++ {
        for j := int32(0); j < L; j++ {
            idx := i*L + j
            if j <= i && j < validLen {
                maskData[idx] = 0 // allowed: causal OK and not PAD
            } else {
                maskData[idx] = negInf // blocked: future or PAD
            }
        }
    }
    maskMat := mlx.NewArray(maskData, []int32{L, L})

    layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")

    // Concatenate layer outputs along the hidden dimension
    // Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
    embeddings := mlx.Concatenate(layerOutputs, 2)

    // Return embeddings and padded length
    return embeddings, int32(maxLen)
}

@@ -17,13 +17,13 @@ import (

// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
    Prompt         string
    NegativePrompt string                     // Empty = no CFG
    CFGScale       float32                    // Only used if NegativePrompt is set (default: 4.0)
    Width          int32                      // Image width (default: 1024)
    Height         int32                      // Image height (default: 1024)
    Steps          int                        // Denoising steps (default: 30)
    Seed           int64                      // Random seed
    Progress       func(step, totalSteps int) // Optional progress callback
    NegativePrompt string       // Empty = no CFG
    CFGScale       float32      // Only used if NegativePrompt is set (default: 4.0)
    Width          int32        // Image width (default: 1024)
    Height         int32        // Image height (default: 1024)
    Steps          int          // Denoising steps (default: 30)
    Seed           int64        // Random seed
    Progress       ProgressFunc // Optional progress callback

    // Layer caching (DeepCache/Learning-to-Cache speedup)
    LayerCache bool // Enable layer caching (default: false)
@@ -31,6 +31,9 @@ type GenerateConfig struct {
    CacheLayers int // Number of shallow layers to cache (default: 25)
}

// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)

// Model represents a Qwen-Image diffusion model.
type Model struct {
    ModelPath string
@@ -114,7 +117,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
}

// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
    return m.GenerateFromConfig(&GenerateConfig{
        Prompt: prompt,
        Width:  width,
@@ -126,7 +129,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
}

// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
    return m.GenerateFromConfig(&GenerateConfig{
        Prompt:         prompt,
        NegativePrompt: negativePrompt,

@@ -18,15 +18,18 @@ import (
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
    Prompt         string
    NegativePrompt string                     // Unconditional prompt for CFG (empty string "" is valid)
    CFGScale       float32                    // CFG enabled when > 1.0 (default: 4.0)
    Width          int32                      // Output width (default: from input image)
    Height         int32                      // Output height (default: from input image)
    Steps          int                        // Denoising steps (default: 50)
    Seed           int64                      // Random seed
    Progress       func(step, totalSteps int) // Optional progress callback
    NegativePrompt string       // Unconditional prompt for CFG (empty string "" is valid)
    CFGScale       float32      // CFG enabled when > 1.0 (default: 4.0)
    Width          int32        // Output width (default: from input image)
    Height         int32        // Output height (default: from input image)
    Steps          int          // Denoising steps (default: 50)
    Seed           int64        // Random seed
    Progress       ProgressFunc // Optional progress callback
}

// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)

// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
    ModelPath string

@@ -3,17 +3,287 @@

package zimage

import (
    "github.com/ollama/ollama/x/imagegen/models/qwen3"
    "fmt"
    "math"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
    "github.com/ollama/ollama/x/imagegen/nn"
    "github.com/ollama/ollama/x/imagegen/safetensors"
    "github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Re-export types from shared qwen3 package for backwards compatibility
type (
    Qwen3Config      = qwen3.Config
    Qwen3Attention   = qwen3.Attention
    Qwen3MLP         = qwen3.MLP
    Qwen3Block       = qwen3.Block
    Qwen3TextEncoder = qwen3.TextEncoder
)
// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
    HiddenSize        int32   `json:"hidden_size"`
    NumHiddenLayers   int32   `json:"num_hidden_layers"`
    IntermediateSize  int32   `json:"intermediate_size"`
    NumAttentionHeads int32   `json:"num_attention_heads"`
    NumKeyValueHeads  int32   `json:"num_key_value_heads"`
    VocabSize         int32   `json:"vocab_size"`
    RMSNormEps        float32 `json:"rms_norm_eps"`
    RopeTheta         float32 `json:"rope_theta"`
    HeadDim           int32   `json:"head_dim"`
}

// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
    QProj nn.LinearLayer `weight:"q_proj"`
    KProj nn.LinearLayer `weight:"k_proj"`
    VProj nn.LinearLayer `weight:"v_proj"`
    OProj nn.LinearLayer `weight:"o_proj"`
    QNorm *nn.RMSNorm    `weight:"q_norm"`
    KNorm *nn.RMSNorm    `weight:"k_norm"`
    // Computed fields
    NHeads    int32
    NKVHeads  int32
    HeadDim   int32
    Scale     float32
    RopeTheta float32
}

// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    L := shape[1]
    H := shape[2]
    D := shape[3]
    half := D / 2

    freqsArr := make([]float32, half)
    logTheta := float32(math.Log(float64(theta)))
    for i := int32(0); i < half; i++ {
        freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
    }
    freqs := mlx.NewArray(freqsArr, []int32{half})

    posArr := make([]float32, seqLen)
    for i := int32(0); i < seqLen; i++ {
        posArr[i] = float32(i)
    }
    pos := mlx.NewArray(posArr, []int32{seqLen})

    posExpanded := mlx.Reshape(pos, seqLen, 1)
    freqsExpanded := mlx.Reshape(freqs, 1, half)
    args := mlx.Mul(posExpanded, freqsExpanded)

    cosVals := mlx.Cos(args)
    sinVals := mlx.Sin(args)
    cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
    sinVals = mlx.Reshape(sinVals, seqLen, 1, half)

    x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
    x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})

    part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
    part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))

    return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}

// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
    shape := x.Shape()
    B := shape[0]
    L := shape[1]

    q := attn.QProj.Forward(x)
    k := attn.KProj.Forward(x)
    v := attn.VProj.Forward(x)

    q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
    k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
    v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

    // QK norm uses a hardcoded eps of 1e-6 (Qwen3-specific)
    q = attn.QNorm.Forward(q, 1e-6)
    k = attn.KNorm.Forward(k, 1e-6)

    q = applyRoPEQwen3(q, L, attn.RopeTheta)
    k = applyRoPEQwen3(k, L, attn.RopeTheta)

    q = mlx.Transpose(q, 0, 2, 1, 3)
    k = mlx.Transpose(k, 0, 2, 1, 3)
    v = mlx.Transpose(v, 0, 2, 1, 3)

    if attn.NKVHeads < attn.NHeads {
        repeats := attn.NHeads / attn.NKVHeads
        k = repeatKV(k, repeats)
        v = repeatKV(v, repeats)
    }

    out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)

    out = mlx.Transpose(out, 0, 2, 1, 3)
    out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

    out = attn.OProj.Forward(out)

    return out
}

// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
    if repeats == 1 {
        return x
    }
    shape := x.Shape()
    x = mlx.ExpandDims(x, 2)
    x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
    return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}

// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
    GateProj nn.LinearLayer `weight:"gate_proj"`
    UpProj   nn.LinearLayer `weight:"up_proj"`
    DownProj nn.LinearLayer `weight:"down_proj"`
}

// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
    gate := m.GateProj.Forward(x)
    gate = mlx.SiLU(gate)
    up := m.UpProj.Forward(x)
    h := mlx.Mul(gate, up)
    return m.DownProj.Forward(h)
}

// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
    Attention         *Qwen3Attention `weight:"self_attn"`
    MLP               *Qwen3MLP       `weight:"mlp"`
    InputLayerNorm    *nn.RMSNorm     `weight:"input_layernorm"`
    PostAttnLayerNorm *nn.RMSNorm     `weight:"post_attention_layernorm"`
}

// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
    h := qb.InputLayerNorm.Forward(x, eps)
    attnOut := qb.Attention.Forward(h)
    x = mlx.Add(x, attnOut)

    h = qb.PostAttnLayerNorm.Forward(x, eps)
    mlpOut := qb.MLP.Forward(h)
    x = mlx.Add(x, mlpOut)

    return x
}

// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
    EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
    Layers      []*Qwen3Block `weight:"model.layers"`
    FinalNorm   *nn.RMSNorm   `weight:"model.norm"`
    *Qwen3Config
}

// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
    fmt.Print(" Loading text encoder... ")

    // Load config from blob
    var cfg Qwen3Config
    if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
        return fmt.Errorf("config: %w", err)
    }
    m.Qwen3Config = &cfg
    m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)

    // Load weights from tensor blobs
    weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
    if err != nil {
        return fmt.Errorf("weights: %w", err)
    }
    if err := weights.Load(0); err != nil {
        return fmt.Errorf("load weights: %w", err)
    }
    defer weights.ReleaseAll()

    return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
    if err := safetensors.LoadModule(m, weights, ""); err != nil {
        return fmt.Errorf("load module: %w", err)
    }
    m.initComputedFields()
    fmt.Println("✓")
    return nil
}

// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
    cfg := m.Qwen3Config
    m.FinalNorm.Eps = cfg.RMSNormEps
    for _, block := range m.Layers {
        // Attention
        block.Attention.NHeads = cfg.NumAttentionHeads
        block.Attention.NKVHeads = cfg.NumKeyValueHeads
        block.Attention.HeadDim = cfg.HeadDim
        block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
        block.Attention.RopeTheta = cfg.RopeTheta
        block.Attention.QNorm.Eps = cfg.RMSNormEps
        block.Attention.KNorm.Eps = cfg.RMSNormEps
        // Block norms
        block.InputLayerNorm.Eps = cfg.RMSNormEps
        block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
    }
}

// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
    h := te.EmbedTokens.Forward(tokens)
    eps := te.RMSNormEps

    for _, layer := range te.Layers {
        h = layer.Forward(h, eps)
    }

    // Apply final RMS norm
    h = te.FinalNorm.Forward(h, eps)

    return h
}

// ApplyChatTemplate wraps prompt in Qwen3 chat format
var ApplyChatTemplate = qwen3.ApplyChatTemplate
func ApplyChatTemplate(prompt string) string {
    return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}

// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
    formattedPrompt := ApplyChatTemplate(prompt)

    tokens := tok.Encode(formattedPrompt, false)

    if len(tokens) > maxLen {
        tokens = tokens[:maxLen]
    }

    maskData := make([]float32, maxLen)
    for i := 0; i < len(tokens); i++ {
        maskData[i] = 1.0
    }

    // Get PAD token (different from EOS for Qwen3)
    padToken := tok.PAD()
    if padToken < 0 {
        padToken = tok.EOS() // fallback
    }

    paddedTokens := make([]int32, maxLen)
    copy(paddedTokens, tokens)
    for i := len(tokens); i < maxLen; i++ {
        paddedTokens[i] = padToken
    }

    tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
    maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})

    embeddings := te.Forward(tokensArr)

    return embeddings, maskArr
}
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
@@ -637,9 +636,6 @@ type VAEDecoder struct {
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
@@ -734,60 +730,45 @@ func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfi
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Input latents are in NCHW format, output is in NCHW format.
|
||||
// If Tiling is set, uses tiled decoding to reduce memory for large images.
|
||||
func (v *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Internally uses NHWC format (MLX native) for all operations.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Scale latents
|
||||
z := mlx.DivScalar(latents, v.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, v.Config.ShiftFactor)
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
// Convert NCHW -> NHWC for internal processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
|
||||
// Use tiled decoding if enabled
|
||||
if v.Tiling != nil {
|
||||
mlx.Eval(z)
|
||||
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
|
||||
}
|
||||
|
||||
// Direct decode
|
||||
h := v.decodeTile(z)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// decodeTile decodes a single latent tile to pixels.
|
||||
// Input: [B, H, W, C] latent tile in NHWC format (already scaled)
|
||||
// Output: [B, H*8, W*8, 3] pixel tile in NHWC format
|
||||
func (v *VAEDecoder) decodeTile(z *mlx.Array) *mlx.Array {
|
||||
h := v.ConvIn.Forward(z)
|
||||
h := vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
|
||||
prev := h
|
||||
h = v.MidBlock.Forward(h)
|
||||
h = vae.MidBlock.Forward(h)
|
||||
prev.Free()
|
||||
|
||||
for _, upBlock := range v.UpBlocks {
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
prev = h
|
||||
h = upBlock.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
prev = h
|
||||
h = v.ConvNormOut.Forward(h)
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||
prev.Free()
|
||||
|
||||
prev = h
|
||||
h = mlx.SiLU(h)
|
||||
h = v.ConvOut.Forward(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
mlx.Eval(h)
|
||||
prev.Free()
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
@@ -35,6 +34,9 @@ type GenerateConfig struct {
|
||||
FusedQKV bool // Enable fused QKV projection (default: false)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
@@ -91,7 +93,7 @@ func (m *Model) Load(modelName string) error {
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
@@ -137,7 +139,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
@@ -149,7 +151,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
@@ -177,16 +179,9 @@ func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*m
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements runner.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
@@ -227,9 +222,9 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
@@ -453,11 +448,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
teaCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode - enable tiling for larger images to reduce memory
|
||||
// VAE attention is O(n²) on latent pixels, tiling helps significantly
|
||||
if latentH > 64 || latentW > 64 {
|
||||
m.VAEDecoder.Tiling = vae.DefaultTilingConfig()
|
||||
}
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
@@ -41,15 +40,10 @@ type Response struct {
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model ImageModel
|
||||
model *zimage.Model
|
||||
modelName string
|
||||
}
|
||||
|
||||
@@ -86,25 +80,10 @@ func Execute(args []string) error {
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(*modelName)
|
||||
slog.Info("detected model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "Flux2KleinPipeline":
|
||||
m := &flux2.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
// Load model
|
||||
model := &zimage.Model{}
|
||||
if err := model.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
@@ -180,19 +159,26 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
		return
	}

	// Generate image using the common interface
	// Generate image
	ctx := r.Context()
	enc := json.NewEncoder(w)

	// Progress callback streams step updates
	progress := func(step, total int) {
		resp := Response{Step: step, Total: total}
		enc.Encode(resp)
		w.Write([]byte("\n"))
		flusher.Flush()
	}

	img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
	img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
		Prompt: req.Prompt,
		Width:  req.Width,
		Height: req.Height,
		Steps:  req.Steps,
		Seed:   req.Seed,
		Progress: func(step, total int) {
			resp := Response{
				Step:  step,
				Total: total,
				Done:  false,
			}
			data, _ := json.Marshal(resp)
			w.Write(data)
			w.Write([]byte("\n"))
			flusher.Flush()
		},
	})

	if err != nil {
		// Don't send error for cancellation

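On the wire, both versions stream the same thing: one JSON-encoded Response per line, flushed after every step. A minimal client-side sketch (assuming the usual bufio/encoding/json/fmt/io imports); only the Response shape comes from the diff:

// readProgress consumes the newline-delimited JSON progress stream.
func readProgress(r io.Reader) error {
	sc := bufio.NewScanner(r)
	for sc.Scan() {
		var ev Response
		if err := json.Unmarshal(sc.Bytes(), &ev); err != nil {
			return err
		}
		fmt.Printf("step %d/%d\n", ev.Step, ev.Total)
	}
	return sc.Err()
}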
@@ -17,24 +17,6 @@ type WeightSource interface {
	GetTensor(name string) (*mlx.Array, error)
	ListTensors() []string
	HasTensor(name string) bool
	Quantization() string // Returns "FP4", "FP8", or ""
}

// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
	switch strings.ToUpper(quantization) {
	case "FP4":
		return 32, 4, "affine"
	default:
		return 32, 8, "affine" // FP8 or unknown
	}
}

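Usage of the removed helper was straightforward: it normalized case and fell back to 8-bit affine parameters for anything it did not recognize:

groupSize, bits, mode := quantizationParams("fp4") // 32, 4, "affine" (case-insensitive)
groupSize, bits, mode = quantizationParams("FP8")  // 32, 8, "affine"
groupSize, bits, mode = quantizationParams("")     // 32, 8, "affine" (default)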
// Transformer allows structs to transform weight arrays before assignment.
// Implement this to apply operations like transpose during loading.
type Transformer interface {
	Transform(field string, arr *mlx.Array) *mlx.Array
}

// LoadModule loads weights into a struct using reflection and struct tags.
@@ -154,10 +136,6 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
			}
			continue
		}
		// Transform before assigning if parent implements Transformer
		if t, ok := v.Addr().Interface().(Transformer); ok {
			arr = t.Transform(field.Name, arr)
		}
		fieldVal.Set(reflect.ValueOf(arr))
		continue
	}
@@ -245,21 +223,19 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
		qbiases, _ = weights.GetTensor(qbiasPath)
	}

	groupSize, bits, mode := quantizationParams(weights.Quantization())

	if mlx.MetalIsAvailable() {
		return &nn.QuantizedLinear{
			Weight:    weight,
			Scales:    scales,
			QBiases:   qbiases,
			Bias:      bias,
			GroupSize: groupSize,
			Bits:      bits,
			Mode:      mode,
			GroupSize: 32,
			Bits:      8,
			Mode:      "affine",
		}, nil
	}

	dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
	dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
	return nn.NewLinear(dequantized, bias), nil
}

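With quantizationParams gone, both branches hardcode 8-bit affine quantization with group size 32. For reference, group-wise affine dequantization of the kind mlx.Dequantize performs reconstructs each weight from its quantized value plus a per-group scale and bias; a minimal CPU sketch under that assumption (not the mlx kernel):

// dequantizeAffine is an illustrative CPU version of group-wise affine
// dequantization: w[i] = q[i]*scale[g] + bias[g], where g = i/groupSize.
// It assumes q has already been unpacked to one integer per weight.
func dequantizeAffine(q []int, scales, biases []float32, groupSize int) []float32 {
	out := make([]float32, len(q))
	for i, v := range q {
		g := i / groupSize
		out[i] = float32(v)*scales[g] + biases[g]
	}
	return out
}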
@@ -298,11 +298,6 @@ func (mw *ModelWeights) HasTensor(name string) bool {
	return ok
}

// Quantization returns empty string for directory-based weights (not quantized).
func (mw *ModelWeights) Quantization() string {
	return ""
}

// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
	for path, native := range mw.nativeCache {

@@ -510,11 +510,7 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
		t.vocab.Merges[merge] = i
	}

	// Add all added_tokens to vocabulary and special tokens map.
	// HuggingFace treats ALL added_tokens as special for tokenization purposes -
	// they bypass BPE and get their own token ID. The "special" flag just indicates
	// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
	// to treat all added_tokens as special to match HuggingFace behavior.
	// Add special tokens to vocabulary
	for _, tok := range raw.AddedTokens {
		if int(tok.ID) >= len(t.vocab.Values) {
			newValues := make([]string, tok.ID+1)
@@ -522,7 +518,9 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
			t.vocab.Values = newValues
		}
		t.vocab.Values[tok.ID] = tok.Content
		t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
		if tok.Special {
			t.specialTokens[tok.Content] = tok.ID
		}
	}

	// Load special token configuration from companion files

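The tokenizer change narrows which added_tokens land in specialTokens: previously every added token was registered (and thus matched verbatim before BPE, per the removed comment); now only those flagged "special": true are. A sketch of the lookup order this affects, with a hypothetical encode step:

// Hypothetical encode step: special tokens are matched verbatim and
// mapped directly to their ID; everything else falls through to BPE.
if id, ok := t.specialTokens[piece]; ok {
	ids = append(ids, id)
} else {
	ids = append(ids, t.bpe(piece)...) // hypothetical BPE fallback
}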
@@ -1,215 +0,0 @@
//go:build mlx

// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
package vae

import (
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// TilingConfig holds configuration for tiled VAE decoding.
// This is a general technique to reduce memory usage when decoding large latents.
type TilingConfig struct {
	TileSize int32 // Tile size in latent space (e.g., 64 latent → 512 pixels for 8x VAE)
	Overlap  int32 // Overlap in latent space (e.g., 16 latent = 25% of 64)
}

// DefaultTilingConfig returns reasonable defaults matching diffusers:
// tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *TilingConfig {
	return &TilingConfig{
		TileSize: 64, // 64 latent pixels
		Overlap:  16, // 25% overlap
	}
}

// decodedTile holds a decoded tile's pixel data and dimensions.
type decodedTile struct {
	data   []float32
	height int32
	width  int32
}

// DecodeTiled decodes latents using tiled processing with overlap blending.
// This reduces memory usage for large images by processing in overlapping tiles.
//
// Parameters:
//   - latents: [1, H, W, C] latent tensor in NHWC format
//   - cfg: tiling configuration (tile size and overlap)
//   - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
//
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
	shape := latents.Shape()
	H := shape[1] // latent height
	W := shape[2] // latent width
	C := shape[3]

	tileLatentSize := cfg.TileSize
	overlapLatent := cfg.Overlap

	// If image is small enough, just decode normally
	if H <= tileLatentSize && W <= tileLatentSize {
		decoded := decoder(latents)
		decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
		decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
		decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
		return decoded
	}

	// Calculate tiling parameters (matching diffusers)
	overlapSize := tileLatentSize - overlapLatent // stride in latent space

	// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
	// For other scale factors, this could be made configurable
	tileSampleSize := tileLatentSize * 8     // tile size in pixels after 8x upscale
	blendExtent := overlapLatent * 8         // blend region in pixels
	rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile

	// Phase 1: Decode all tiles and store in 2D grid
	var rows [][]decodedTile

	for i := int32(0); i < H; i += overlapSize {
		var row []decodedTile
		for j := int32(0); j < W; j += overlapSize {
			// Extract tile (may be smaller at edges)
			i2 := min(i+tileLatentSize, H)
			j2 := min(j+tileLatentSize, W)

			tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
			decoded := decoder(tile)
			decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
			mlx.Eval(decoded)

			decodedShape := decoded.Shape()
			tileH := decodedShape[1]
			tileW := decodedShape[2]
			tileData := decoded.Data()
			decoded.Free()

			row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
		}
		rows = append(rows, row)
	}

	// Phase 2: Blend adjacent tiles (modifies in place)
	for i := range rows {
		for j := range rows[i] {
			tile := &rows[i][j]

			// Blend with tile above
			if i > 0 {
				above := &rows[i-1][j]
				blendV(above, tile, blendExtent)
			}

			// Blend with tile to the left
			if j > 0 {
				left := &rows[i][j-1]
				blendH(left, tile, blendExtent)
			}
		}
	}

	// Phase 3: Calculate crop dimensions for each tile
	colWidths := make([]int32, len(rows[0]))
	for j := range rows[0] {
		keepW := rowLimit
		if int32(j+1)*overlapSize >= W {
			keepW = rows[0][j].width
		}
		colWidths[j] = keepW
	}

	rowHeights := make([]int32, len(rows))
	for i := range rows {
		keepH := rowLimit
		if int32(i+1)*overlapSize >= H {
			keepH = rows[i][0].height
		}
		rowHeights[i] = keepH
	}

	// Calculate total dimensions
	var totalW, totalH int32
	for _, w := range colWidths {
		totalW += w
	}
	for _, h := range rowHeights {
		totalH += h
	}

	// Phase 4: Assemble final image by interleaving tiles row-by-row
	finalData := make([]float32, totalH*totalW*3)

	dstY := int32(0)
	for i, row := range rows {
		keepH := rowHeights[i]

		for y := int32(0); y < keepH; y++ {
			dstX := int32(0)
			for j, tile := range row {
				keepW := colWidths[j]

				for x := int32(0); x < keepW; x++ {
					for c := int32(0); c < 3; c++ {
						srcIdx := (y*tile.width+x)*3 + c
						dstIdx := ((dstY+y)*totalW+(dstX+x))*3 + c
						finalData[dstIdx] = tile.data[srcIdx]
					}
				}
				dstX += keepW
			}
		}
		dstY += keepH
	}

	// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
	result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
	result = mlx.Transpose(result, 0, 3, 1, 2)
	result = mlx.ClipScalar(result, 0.0, 1.0, true, true)

	return result
}

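A minimal usage sketch for the deleted DecodeTiled, assuming a decoder value dec with a Decode(*mlx.Array) *mlx.Array method (the wiring here is illustrative, not from the diff):

cfg := vae.DefaultTilingConfig()
img := vae.DecodeTiled(latents, cfg, func(tile *mlx.Array) *mlx.Array {
	return dec.Decode(tile) // decode one [1, h, w, C] latent tile
})
// img is [1, 3, H*8, W*8] in NCHW, clipped to [0, 1]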
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend).
// Matches diffusers blend_v formula.
func blendV(above, current *decodedTile, blendExtent int32) {
	blend := min(blendExtent, min(above.height, current.height))
	if blend <= 0 {
		return
	}

	w := min(above.width, current.width)
	for y := int32(0); y < blend; y++ {
		alpha := float32(y) / float32(blend)
		for x := int32(0); x < w; x++ {
			for c := int32(0); c < 3; c++ {
				aboveIdx := ((above.height-blend+y)*above.width+x)*3 + c
				currIdx := (y*current.width+x)*3 + c
				current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
			}
		}
	}
}

// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend).
// Matches diffusers blend_h formula.
func blendH(left, current *decodedTile, blendExtent int32) {
	blend := min(blendExtent, min(left.width, current.width))
	if blend <= 0 {
		return
	}

	h := min(left.height, current.height)
	for y := int32(0); y < h; y++ {
		for x := int32(0); x < blend; x++ {
			alpha := float32(x) / float32(blend)
			for c := int32(0); c < 3; c++ {
				leftIdx := (y*left.width+(left.width-blend+x))*3 + c
				currIdx := (y*current.width+x)*3 + c
				current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
			}
		}
	}
}
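Both blend helpers implement the same linear crossfade over the overlap region: at offset k into the overlap, alpha = k/blend, and the output is neighbor*(1-alpha) + current*alpha. With the defaults (Overlap of 16 latent pixels, so blendExtent = 16*8 = 128 pixels), the first overlapping row is taken entirely from the upper or left tile (alpha = 0) and the last is almost entirely the current tile (alpha = 127/128), which is what hides the tile seams.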
@@ -106,21 +106,6 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
	return ok
}

// Quantization returns the model's quantization type from model_index.json.
// Returns empty string if not quantized or unknown.
func (mw *ManifestWeights) Quantization() string {
	if mw.manifest == nil {
		return ""
	}
	var index struct {
		Quantization string `json:"quantization"`
	}
	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
		return ""
	}
	return index.Quantization
}

// ReleaseAll frees all native handles and clears the tensor cache.
func (mw *ManifestWeights) ReleaseAll() {
	for _, sf := range mw.nativeCache {

@@ -9,8 +9,7 @@ import (
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/imagegen"
)

// modelConfig represents the HuggingFace config.json structure
@@ -36,22 +35,22 @@ type modelConfig struct {

// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
	mf, err := manifest.ParseNamedManifest(name)
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to load manifest: %w", err)
	}

	var config modelConfig
	if err := mf.ReadConfigJSON("config.json", &config); err != nil {
	if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
		return nil, fmt.Errorf("failed to read config.json: %w", err)
	}

	// Calculate total tensor bytes from manifest layers
	var totalBytes int64
	var tensorCount int64
	for _, layer := range mf.Layers {
		if layer.MediaType == manifest.MediaTypeImageTensor {
	for _, layer := range manifest.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.tensor" {
			totalBytes += layer.Size
			tensorCount++
		}
@@ -152,30 +151,27 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map

// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
	mf, err := manifest.ParseNamedManifest(name)
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to load manifest: %w", err)
	}

	return getTensorInfoFromManifest(mf)
	return getTensorInfoFromManifest(manifest)
}

// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
	var tensors []api.Tensor

	for _, layer := range mf.Layers {
		if layer.MediaType != manifest.MediaTypeImageTensor {
	for _, layer := range manifest.Manifest.Layers {
		if layer.MediaType != "application/vnd.ollama.image.tensor" {
			continue
		}

		// Read the safetensors header from the blob
		blobPath, err := manifest.GetBlobsPath(layer.Digest)
		if err != nil {
			continue
		}
		blobPath := manifest.BlobPath(layer.Digest)
		info, err := readSafetensorsHeader(blobPath)
		if err != nil {
			// Skip tensors we can't read
@@ -201,15 +197,15 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
	mf, err := manifest.ParseNamedManifest(name)
func GetSafetensorsDtype(modelName string) (string, error) {
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return "", fmt.Errorf("failed to load manifest: %w", err)
	}

	// Check if model is quantized by looking for _scale tensors
	for _, layer := range mf.Layers {
		if layer.MediaType == manifest.MediaTypeImageTensor {
	for _, layer := range manifest.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.tensor" {
			if strings.HasSuffix(layer.Name, "_scale") {
				// Model is quantized - return FP8 (affine quantization)
				return "FP8", nil
@@ -221,7 +217,7 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
	var cfg struct {
		TorchDtype string `json:"torch_dtype"`
	}
	if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
	if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
		return "", fmt.Errorf("failed to read config.json: %w", err)
	}

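The "88-byte header" mentioned above follows the standard safetensors layout: an 8-byte little-endian length prefix followed by that many bytes of JSON mapping tensor names to dtype, shape, and data offsets (the test below constructs blobs the same way). A minimal reader sketch under that assumption; readSafetensorsHeader itself is not shown in this diff:

// readHeader is a hypothetical stand-in: parse the safetensors header
// (8-byte little-endian length, then that many bytes of JSON).
func readHeader(blobPath string) (map[string]json.RawMessage, error) {
	f, err := os.Open(blobPath)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var n uint64
	if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
		return nil, err
	}
	buf := make([]byte, n)
	if _, err := io.ReadFull(f, buf); err != nil {
		return nil, err
	}
	var header map[string]json.RawMessage // name -> {dtype, shape, data_offsets}
	err = json.Unmarshal(buf, &header)
	return header, err
}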
@@ -8,7 +8,7 @@ import (
	"path/filepath"
	"testing"

	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/x/imagegen"
)

func TestBuildModelInfo(t *testing.T) {
@@ -451,14 +451,8 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
}

func TestGetTensorInfoFromManifest(t *testing.T) {
	// Create a temp directory for blobs and set OLLAMA_MODELS
	// Create a temp directory for blobs
	tempDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", tempDir)

	blobDir := filepath.Join(tempDir, "blobs")
	if err := os.MkdirAll(blobDir, 0o755); err != nil {
		t.Fatalf("failed to create blobs dir: %v", err)
	}

	// Create test tensor blobs
	tensors := []struct {
@@ -469,26 +463,26 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
	}{
		{
			name:   "model.embed_tokens.weight",
			digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
			digest: "sha256:abc123",
			dtype:  "BF16",
			shape:  []int64{262144, 2560},
		},
		{
			name:   "model.layers.0.self_attn.q_proj.weight",
			digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
			digest: "sha256:def456",
			dtype:  "BF16",
			shape:  []int64{2560, 2560},
		},
		{
			name:   "model.norm.weight",
			digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
			digest: "sha256:ghi789",
			dtype:  "F32",
			shape:  []int64{2560},
		},
	}

	// Create blob files
	var layers []manifest.Layer
	var layers []imagegen.ManifestLayer
	for _, tensor := range tensors {
		// Create safetensors blob
		header := map[string]any{
@@ -504,17 +498,15 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
		binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
		buf.Write(headerJSON)

		// Write blob file using the digest format expected by GetBlobsPath
		blobPath, err := manifest.GetBlobsPath(tensor.digest)
		if err != nil {
			t.Fatalf("failed to get blob path: %v", err)
		}
		// Write blob file
		blobName := "sha256-" + tensor.digest[7:]
		blobPath := filepath.Join(tempDir, blobName)
		if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
			t.Fatalf("failed to write blob: %v", err)
		}

		layers = append(layers, manifest.Layer{
			MediaType: manifest.MediaTypeImageTensor,
		layers = append(layers, imagegen.ManifestLayer{
			MediaType: "application/vnd.ollama.image.tensor",
			Digest:    tensor.digest,
			Size:      int64(buf.Len() + 1000), // header + fake data
			Name:      tensor.name,
@@ -522,20 +514,21 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
	}

	// Add a non-tensor layer (should be skipped)
	layers = append(layers, manifest.Layer{
	layers = append(layers, imagegen.ManifestLayer{
		MediaType: "application/vnd.ollama.image.json",
		Digest:    "sha256:0000000000000000000000000000000000000000000000000000000000000000",
		Digest:    "sha256:config",
		Size:      100,
		Name:      "config.json",
	})

	mf := &manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Layers:        layers,
	manifest := &imagegen.ModelManifest{
		Manifest: &imagegen.Manifest{
			Layers: layers,
		},
		BlobDir: tempDir,
	}

	result, err := getTensorInfoFromManifest(mf)
	result, err := getTensorInfoFromManifest(manifest)
	if err != nil {
		t.Fatalf("getTensorInfoFromManifest() error = %v", err)
	}
