mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 21:08:16 -05:00
Compare commits
1 Commits
imagegen-a
...
fix-mlx-qu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
582d93ab22 |
105
cmd/cmd.go
105
cmd/cmd.go
@@ -101,6 +101,67 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) || filename == "" {
|
||||
// No Modelfile specified or found - use default
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
}
|
||||
|
||||
// Parse the Modelfile
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
@@ -136,28 +197,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if this is a tensor model (image generation) and handle it directly
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
modelDir := filepath.Dir(filename)
|
||||
for _, cmd := range modelfile.Commands {
|
||||
if cmd.Name == "model" {
|
||||
if filepath.IsAbs(cmd.Args) {
|
||||
modelDir = cmd.Args
|
||||
} else {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), cmd.Args)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if create.IsTensorModelDir(modelDir) {
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: xcreateclient.ExtractModelfileConfig(modelfile),
|
||||
}, p)
|
||||
}
|
||||
|
||||
status := "gathering model components"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
@@ -169,6 +208,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = modelName
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
}
|
||||
@@ -859,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
for _, arg := range args {
|
||||
// Unload the model if it's running before deletion
|
||||
if err := loadOrUnloadModel(cmd, &runOptions{
|
||||
Model: arg,
|
||||
Model: args[0],
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
|
||||
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1775,15 +1815,22 @@ func NewCLI() *cobra.Command {
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Skip server check for experimental mode (writes directly to disk)
|
||||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
|
||||
@@ -311,8 +311,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
conv = &deepseekocr{}
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
case "Glm4MoeLiteForCausalLM":
|
||||
conv = &glm4MoeLiteModel{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
// glm4MoeLiteModel carries the hyperparameters of a GLM-4 MoE Lite checkpoint
// as decoded from JSON. Its KV method republishes them under "glm4moelite.*"
// metadata keys.
type glm4MoeLiteModel struct {
	ModelParameters
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	HiddenLayers          uint32  `json:"num_hidden_layers"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`

	// Rotary embedding and low-rank attention projection sizes.
	RopeTheta     float32 `json:"rope_theta"`
	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
	KVLoraRank    uint32  `json:"kv_lora_rank"`
	QLoraRank     uint32  `json:"q_lora_rank"`
	VHeadDim      uint32  `json:"v_head_dim"`

	// Mixture-of-experts settings.
	ExpertCount            uint32  `json:"n_routed_experts"`
	ExpertSharedCount      uint32  `json:"n_shared_experts"`
	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`

	// Published as "leading_dense_block_count" by KV.
	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
|
||||
|
||||
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "glm4moelite"
|
||||
kv["general.type"] = "model"
|
||||
kv["glm4moelite.block_count"] = p.HiddenLayers
|
||||
|
||||
numHeads := p.NumAttentionHeads
|
||||
numKVHeads := p.NumKeyValueHeads
|
||||
|
||||
kv["glm4moelite.attention.head_count"] = numHeads
|
||||
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
|
||||
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
|
||||
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
|
||||
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
|
||||
kv["glm4moelite.attention.value_length"] = p.VHeadDim
|
||||
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["glm4moelite.embedding_length"] = p.HiddenSize
|
||||
kv["glm4moelite.expert_count"] = p.ExpertCount
|
||||
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
|
||||
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
|
||||
|
||||
kv["glm4moelite.expert_gating_func"] = uint32(2)
|
||||
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
|
||||
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
|
||||
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
|
||||
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
|
||||
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
|
||||
|
||||
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "glm4"
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
|
||||
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
|
||||
"self_attn.kv_b_proj", "attn_kv_b",
|
||||
"self_attn.q_a_proj", "attn_q_a",
|
||||
"self_attn.q_a_layernorm", "attn_q_a_norm",
|
||||
"self_attn.q_b_proj", "attn_q_b",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.shared_experts.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_experts.up_proj", "ffn_up_shexp",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
skipLayer := func(n string, minValue uint32) bool {
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(n)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return uint32(blkNum) >= minValue
|
||||
}
|
||||
|
||||
out, s = mergeTensors(s, merges...)
|
||||
for _, t := range s {
|
||||
// skip any additional layers (such as the Multi-Token Prediction layer)
|
||||
if skipLayer(t.Name(), p.HiddenLayers) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -269,7 +269,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -857,7 +856,6 @@ func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"gptoss", "gpt-oss",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
package glm4moelite
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// Options holds the hyperparameters shared by every layer's forward pass.
// New populates it from the model configuration.
type Options struct {
	numExpertsUsed      int     // number of experts selected per token (top-k size)
	numExperts          int     // number of routed experts scored by the router
	normTopKProb        bool    // when set, the selected expert weights are divided by their sum
	routedScalingFactor float32 // multiplier applied to the selected expert weights

	// Attention head geometry. New sets qkNopeHeadDim and kqNopeHeadDim to
	// the same value (key length minus the rotary dimension count).
	kvLoraRank,
	qkNopeHeadDim,
	qkRopeHeadDim,
	kqNopeHeadDim,
	qkHeadDim int
	qLoraRank int // 0 selects the direct query projection instead of the low-rank pair
	vHeadDim  int

	hiddenSize,
	numHeads,
	numKVHeads int

	eps,
	ropeBase float32
	kqScale float64 // 1/sqrt(key length), computed in New
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
|
||||
}
|
||||
|
||||
// Attention holds the projection weights for one attention block.
//
// The query is produced either by Q directly or by the QA -> QANorm -> QB
// chain; Forward picks between them based on Options.qLoraRank. Keys and
// values are produced from the KVA projection, with part of it normalised by
// KVANorm and expanded by KVB.
type Attention struct {
	// Direct query projection, used when qLoraRank is 0.
	Q *nn.Linear `gguf:"attn_q"`

	// Low-rank query projection, used when qLoraRank is non-zero.
	QA     *nn.Linear  `gguf:"attn_q_a"`
	QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
	QB     *nn.Linear  `gguf:"attn_q_b"`

	// Compressed key/value projection and its expansion.
	KVA     *nn.Linear  `gguf:"attn_kv_a_mqa"`
	KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
	KVB     *nn.Linear  `gguf:"attn_kv_b"`

	// Output projection applied to the attention result.
	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
|
||||
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
if opts.qLoraRank == 0 {
|
||||
query = attn.Q.Forward(ctx, hiddenStates)
|
||||
} else {
|
||||
query = attn.QA.Forward(ctx, hiddenStates)
|
||||
query = attn.QANorm.Forward(ctx, query, opts.eps)
|
||||
query = attn.QB.Forward(ctx, query)
|
||||
}
|
||||
|
||||
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
|
||||
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
|
||||
|
||||
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
|
||||
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
|
||||
kRot := compressedKV.View(ctx,
|
||||
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
|
||||
compressedKV.Stride(1), 1,
|
||||
compressedKV.Stride(1), compressedKV.Dim(1),
|
||||
)
|
||||
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
// MLP is the feed-forward part of a Layer. New assigns either a *dense or a
// *sparse implementation to each layer.
type MLP interface {
	// Forward applies the feed-forward block to the given hidden states.
	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
|
||||
|
||||
// sparse is the mixture-of-experts MLP: a router scores the experts, the top
// scoring ones are evaluated through the stacked Gate/Up/Down expert weights,
// and a shared dense expert is added to the result.
type sparse struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`                     // produces one logit per expert
	Gate         *nn.Linear `gguf:"ffn_gate_exps"`                    // stacked expert gate projections
	Up           *nn.Linear `gguf:"ffn_up_exps"`                      // stacked expert up projections
	Down         *nn.Linear `gguf:"ffn_down_exps"`                    // stacked expert down projections
	SharedExpert *dense     `gguf:",suf:_shexp"`                      // dense expert applied to every token
	ExpProbsBias ml.Tensor  `gguf:"exp_probs_b.bias,alt:exp_probs_b"` // optional bias added to scores for expert selection only
}
|
||||
|
||||
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
hiddenStates = hiddenStates.SILU(ctx, upStates)
|
||||
|
||||
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
experts = experts.Mul(ctx, topKWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
return nextStates
|
||||
}
|
||||
|
||||
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
|
||||
if moe.ExpProbsBias != nil {
|
||||
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||
}
|
||||
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
|
||||
return topKIndices
|
||||
}
|
||||
|
||||
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
residuals := hiddenStates
|
||||
|
||||
routerLogits := moe.Router.Forward(ctx, hiddenStates)
|
||||
scores := routerLogits.Sigmoid(ctx)
|
||||
topKIndices := moe.topKIndices(ctx, scores, opts)
|
||||
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
|
||||
|
||||
if opts.normTopKProb {
|
||||
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
|
||||
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
}
|
||||
|
||||
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
|
||||
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
|
||||
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
// dense is a gated feed-forward MLP. It is used for the leading dense layers
// and as the shared expert inside sparse.
type dense struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// Layer is one transformer block: a normalised attention sub-block followed
// by a normalised MLP sub-block, each wrapped in a residual connection.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *Attention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     MLP // *dense or *sparse, chosen in New
}
|
||||
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
residual = hiddenStates
|
||||
|
||||
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
// Model is the glm4moelite text model: token embedding, a stack of Layers,
// a final norm and the output projection, together with its byte-pair
// tokenizer and shared Options.
type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`

	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	// Output names "token_embd" as its alternate tensor.
	Output *nn.Linear `gguf:"output,alt:token_embd"`

	*Options
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
|
||||
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
||||
for i := range layers {
|
||||
if i < firstDenseLayerIndex {
|
||||
layers[i].MLP = &dense{}
|
||||
} else {
|
||||
layers[i].MLP = &sparse{}
|
||||
}
|
||||
}
|
||||
|
||||
keyLength := int(c.Uint("attention.key_length"))
|
||||
valueLength := int(c.Uint("attention.value_length"))
|
||||
|
||||
kqScale := 1.0 / math.Sqrt(float64(keyLength))
|
||||
|
||||
var pre []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
case "glm4":
|
||||
pre = []string{
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
}
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
pre...,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("expert_weights_norm", true),
|
||||
|
||||
qLoraRank: int(c.Uint("attention.q_lora_rank")),
|
||||
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
|
||||
qkHeadDim: keyLength,
|
||||
vHeadDim: valueLength,
|
||||
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
|
||||
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
|
||||
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
|
||||
|
||||
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||
|
||||
kqScale: kqScale,
|
||||
},
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
// init registers New as the constructor for the "glm4moelite" architecture.
func init() {
	model.Register("glm4moelite", New)
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
|
||||
@@ -1,410 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// glm46ParserState identifies where the streaming GLM-4.6 parser currently is
// in the model output.
type glm46ParserState int

const (
	// Initial state: waiting to see whether the output starts with <think>.
	glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
	// <think> was seen but only whitespace has followed it so far.
	glm46ParserState_ThinkingStartedEatingWhitespace
	// Emitting thinking content until </think> is found.
	glm46ParserState_CollectingThinking
	// </think> was seen but only whitespace has followed it so far.
	glm46ParserState_ThinkingDoneEatingWhitespace
	// Emitting regular content until <tool_call> is found.
	glm46ParserState_CollectingContent
	// <tool_call> was seen but nothing has followed it yet.
	glm46ParserState_ToolStartedEatingWhitespace
	// Inside a tool call, after <tool_call>.
	glm46ParserState_CollectingToolContent
)
|
||||
|
||||
// Literal tags that delimit thinking and tool-call sections in the model
// output.
const (
	glm46ThinkingOpenTag  = "<think>"
	glm46ThinkingCloseTag = "</think>"
	glm46ToolOpenTag      = "<tool_call>"
	glm46ToolCloseTag     = "</tool_call>"
)
|
||||
|
||||
// GLM46Parser incrementally splits streamed GLM-4.6 output into regular
// content, thinking content and tool calls.
type GLM46Parser struct {
	state  glm46ParserState // current position in the output grammar
	buffer strings.Builder  // input received but not yet emitted as events
	tools  []api.Tool       // tools supplied to Init, passed to tool-call parsing
}
|
||||
|
||||
// HasToolSupport reports that this parser extracts tool calls; it always
// returns true.
func (p *GLM46Parser) HasToolSupport() bool {
	return true
}
|
||||
|
||||
// HasThinkingSupport reports that this parser extracts thinking content; it
// always returns true.
func (p *GLM46Parser) HasThinkingSupport() bool {
	return true
}
|
||||
|
||||
// Init stores tools for later tool-call parsing and returns them unchanged.
// lastMessage and thinkValue are accepted but not used.
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	return tools
}
|
||||
|
||||
// glm46Event is implemented by the event types the parser emits. The
// unexported marker method restricts implementations to this package.
type glm46Event interface {
	isGLM46Event()
}

// glm46EventContent carries a piece of regular (non-thinking) output text.
type glm46EventContent struct {
	content string
}

func (glm46EventContent) isGLM46Event() {}

// glm46EventRawToolCall carries the unparsed text of one tool call.
type glm46EventRawToolCall struct {
	raw string
}

func (glm46EventRawToolCall) isGLM46Event() {}

// glm46EventThinkingContent carries a piece of thinking text.
type glm46EventThinkingContent struct {
	content string
}

func (glm46EventThinkingContent) isGLM46Event() {}
|
||||
|
||||
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case glm46EventRawToolCall:
|
||||
toolCall, err := parseGLM46ToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case glm46EventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here.
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||
var all []glm46Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []glm46Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
||||
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false // Still only whitespace, keep waiting for more input
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true // Successfully transitioned
|
||||
}
|
||||
|
||||
// glm46SplitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
||||
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
|
||||
var events []glm46Event
|
||||
|
||||
switch p.state {
|
||||
case glm46ParserState_LookingForThinkingOpen:
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
|
||||
// Found <think> opening tag
|
||||
after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
p.state = glm46ParserState_ThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = glm46ParserState_CollectingThinking
|
||||
}
|
||||
return events, true
|
||||
} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
|
||||
// Partial opening tag seen, keep accumulating
|
||||
return events, false
|
||||
} else if trimmed == "" {
|
||||
// Only whitespace, keep accumulating
|
||||
return events, false
|
||||
} else {
|
||||
// No thinking tag found, skip to content collection
|
||||
p.state = glm46ParserState_CollectingContent
|
||||
// Don't trim - we want to keep the original content
|
||||
return events, true
|
||||
}
|
||||
|
||||
case glm46ParserState_ThinkingStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)
|
||||
|
||||
case glm46ParserState_CollectingThinking:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, glm46ThinkingCloseTag) {
|
||||
thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: thinking})
|
||||
}
|
||||
if remaining == "" {
|
||||
p.state = glm46ParserState_ThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
p.state = glm46ParserState_CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
|
||||
// Partial closing tag - withhold it along with any trailing whitespace before it
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case glm46ParserState_ThinkingDoneEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)
|
||||
|
||||
case glm46ParserState_CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
|
||||
before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, glm46EventContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = glm46ParserState_ToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = glm46ParserState_CollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case glm46ParserState_ToolStartedEatingWhitespace:
|
||||
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)
|
||||
|
||||
case glm46ParserState_CollectingToolContent:
|
||||
acc := p.buffer.String()
|
||||
if strings.Contains(acc, glm46ToolCloseTag) {
|
||||
toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
|
||||
if len(toolContent) == 0 {
|
||||
slog.Warn("glm46 tool call closing tag found but no content before it")
|
||||
}
|
||||
events = append(events, glm46EventRawToolCall{raw: toolContent})
|
||||
p.state = glm46ParserState_CollectingContent
|
||||
return events, true
|
||||
} else {
|
||||
// Keep accumulating - tool calls are not streamed
|
||||
// We just wait for the closing tag
|
||||
return events, false
|
||||
}
|
||||
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// GLMToolCallXML is the encoding/xml decoding target for a single GLM-4.6
// tool call that has been wrapped in a <tool_call> root element.
type GLMToolCallXML struct {
	// XMLName pins the name of the root element.
	XMLName xml.Name `xml:"tool_call"`

	// Content collects the character data found directly under the root,
	// which is where the function name lives (text nodes between tags).
	Content string `xml:",chardata"`

	// Keys lists every <arg_key> element in document order.
	Keys []string `xml:"arg_key"`

	// Values lists every <arg_value> element in document order.
	Values []string `xml:"arg_value"`
}
|
||||
|
||||
// escapeGLM46Content prepares the raw body of a GLM-4.6 tool call for XML
// parsing. The four structural tags (<arg_key>, </arg_key>, <arg_value>,
// </arg_value>) are copied through untouched; everywhere else the characters
// '&', '<' and '>' are replaced with the entity references "&amp;", "&lt;"
// and "&gt;" so that free-form argument text cannot break the XML document.
//
// Every '&' outside a tag is escaped, including one that already starts an
// entity reference, so the input is always treated as literal text.
func escapeGLM46Content(s string) string {
	var result strings.Builder
	result.Grow(len(s))
	inTag := false

	// Byte-wise iteration is safe for UTF-8 input: the three bytes we react to
	// are ASCII and never occur inside a multi-byte sequence.
	for i := 0; i < len(s); i++ {
		ch := s[i]

		if ch == '<' {
			// Only the known argument tags count as markup.
			rest := s[i:]
			if strings.HasPrefix(rest, "<arg_key>") ||
				strings.HasPrefix(rest, "</arg_key>") ||
				strings.HasPrefix(rest, "<arg_value>") ||
				strings.HasPrefix(rest, "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			// Copy the tag verbatim up to and including its closing '>'.
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
			continue
		}

		// Escape special characters in text content.
		switch ch {
		case '&':
			result.WriteString("&amp;")
		case '<':
			result.WriteString("&lt;")
		case '>':
			result.WriteString("&gt;")
		default:
			result.WriteByte(ch)
		}
	}

	return result.String()
}
|
||||
|
||||
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
// Escape any unescaped entities in text content
|
||||
// We need to escape text between tags, but not the tags themselves
|
||||
escaped := escapeGLM46Content(raw.raw)
|
||||
|
||||
// Wrap the content in a root element to make it valid XML
|
||||
xmlString := "<tool_call>" + escaped + "</tool_call>"
|
||||
|
||||
// Parse XML into struct
|
||||
var parsed GLMToolCallXML
|
||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
}
|
||||
|
||||
// Extract and trim function name
|
||||
functionName := strings.TrimSpace(parsed.Content)
|
||||
if functionName == "" {
|
||||
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||
}
|
||||
|
||||
// Verify keys and values are paired correctly
|
||||
if len(parsed.Keys) != len(parsed.Values) {
|
||||
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
|
||||
}
|
||||
|
||||
// Find the matching tool to get parameter types
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == functionName {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Build arguments map by pairing keys and values
|
||||
toolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: functionName,
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range parsed.Keys {
|
||||
key := strings.TrimSpace(parsed.Keys[i])
|
||||
value := parsed.Values[i] // Don't trim here - parseValue handles it
|
||||
|
||||
// Look up parameter type
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
|
||||
// Handle anyOf by collecting all types from the union
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
paramType = append(paramType, anyOfProp.Type...)
|
||||
}
|
||||
} else {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse value with type coercion
|
||||
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
@@ -1,862 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM46ParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []glm46Event
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
steps []step
|
||||
only bool
|
||||
}{
|
||||
{
|
||||
desc: "leading whitespace before think tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " \n\t ",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "<think>thinking</think>",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "think tag with whitespace inside",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think> \n thinking content \n </think>regular content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking content"},
|
||||
glm46EventContent{content: "regular content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with leading whitespace after opening tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think><tool_call> \n test \n </tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "simple thinking then content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>I am thinking</think>Now I respond",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "I am thinking"},
|
||||
glm46EventContent{content: "Now I respond"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "streamed thinking content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>hello",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
|
||||
},
|
||||
{
|
||||
input: " world",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
|
||||
},
|
||||
{
|
||||
input: "</think>content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>Let me call a tool</think>here is text<tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "Let me call a tool"},
|
||||
glm46EventContent{content: "here is text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call with content after",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "after tool"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between content and tool call is trimmed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
glm46EventContent{content: "content"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between tool call and content is trimmed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "think"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "after"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split thinking close tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking content</th",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
|
||||
},
|
||||
{
|
||||
input: "ink>after",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "after"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split thinking open tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <thi",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "nk>content</think>",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "split tool open tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think>content<tool",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
|
||||
},
|
||||
{
|
||||
input: "_call>inside",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "inside"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial thinking close tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>content</th",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
|
||||
},
|
||||
{
|
||||
input: "ought more",
|
||||
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial thinking open tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <thi",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "nking is fun",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: " <thinking is fun"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tool open tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think>content\n<tool",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " fakeout",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "\n<tool fakeout"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tool close tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think><tool_call>content</tool",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: " fakeout",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
{
|
||||
input: "</tool_call>",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "content</tool fakeout"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think>content here",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content here"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple tool calls in sequence",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "think"},
|
||||
glm46EventRawToolCall{raw: "first"},
|
||||
glm46EventContent{content: "between"},
|
||||
glm46EventRawToolCall{raw: "second"},
|
||||
glm46EventContent{content: "end"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no thinking tag - direct to content",
|
||||
steps: []step{
|
||||
{
|
||||
input: "just content here",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "just content here"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no thinking tag - skip to content then tool call",
|
||||
steps: []step{
|
||||
{
|
||||
input: "Here's the answer:<tool_call>test</tool_call>done",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "Here's the answer:"},
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "done"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no thinking tag - whitespace preserved when no tags",
|
||||
steps: []step{
|
||||
{
|
||||
input: " \n content with leading whitespace",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: " \n content with leading whitespace"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after think close tag gets eaten",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking</think> \n\t content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after tool_call close tag gets eaten",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think></think><tool_call>test</tool_call> \n\t content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventRawToolCall{raw: "test"},
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking content withholds trailing whitespace (single chunk)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking content ",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "</think>after",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "after"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking content withholds trailing whitespace with newlines",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking\n\n ",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "</think>content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking content trailing whitespace emitted when more content arrives",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking ",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "more thinking",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: " more thinking"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "</think>",
|
||||
wantEvents: []glm46Event{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "thinking content withholds trailing whitespace before partial close tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>thinking </th",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventThinkingContent{content: "thinking"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "ink>content",
|
||||
wantEvents: []glm46Event{
|
||||
glm46EventContent{content: "content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anyOnlies := false
|
||||
for _, tc := range cases {
|
||||
if tc.only {
|
||||
anyOnlies = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
if anyOnlies && !tc.only {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := GLM46Parser{}
|
||||
|
||||
for i, step := range tc.steps {
|
||||
parser.buffer.WriteString(step.input)
|
||||
gotEvents := parser.parseEvents()
|
||||
|
||||
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||
// avoid deep equal on empty vs. nil slices
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
|
||||
// document order when collecting multiple elements with the same tag name into slices.
|
||||
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
|
||||
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
xml string
|
||||
wantKeys []string
|
||||
wantValues []string
|
||||
}{
|
||||
{
|
||||
name: "alternating keys and values",
|
||||
xml: `<tool_call>
|
||||
function_name
|
||||
<arg_key>first</arg_key>
|
||||
<arg_value>A</arg_value>
|
||||
<arg_key>second</arg_key>
|
||||
<arg_value>B</arg_value>
|
||||
<arg_key>third</arg_key>
|
||||
<arg_value>C</arg_value>
|
||||
</tool_call>`,
|
||||
wantKeys: []string{"first", "second", "third"},
|
||||
wantValues: []string{"A", "B", "C"},
|
||||
},
|
||||
{
|
||||
name: "all keys then all values",
|
||||
xml: `<tool_call>
|
||||
function_name
|
||||
<arg_key>key1</arg_key>
|
||||
<arg_key>key2</arg_key>
|
||||
<arg_key>key3</arg_key>
|
||||
<arg_value>val1</arg_value>
|
||||
<arg_value>val2</arg_value>
|
||||
<arg_value>val3</arg_value>
|
||||
</tool_call>`,
|
||||
wantKeys: []string{"key1", "key2", "key3"},
|
||||
wantValues: []string{"val1", "val2", "val3"},
|
||||
},
|
||||
{
|
||||
name: "mixed grouping",
|
||||
xml: `<tool_call>
|
||||
function_name
|
||||
<arg_key>a</arg_key>
|
||||
<arg_value>1</arg_value>
|
||||
<arg_key>b</arg_key>
|
||||
<arg_key>c</arg_key>
|
||||
<arg_value>2</arg_value>
|
||||
<arg_value>3</arg_value>
|
||||
</tool_call>`,
|
||||
wantKeys: []string{"a", "b", "c"},
|
||||
wantValues: []string{"1", "2", "3"},
|
||||
},
|
||||
{
|
||||
name: "reverse order - all values then all keys",
|
||||
xml: `<tool_call>
|
||||
function_name
|
||||
<arg_value>X</arg_value>
|
||||
<arg_value>Y</arg_value>
|
||||
<arg_value>Z</arg_value>
|
||||
<arg_key>x</arg_key>
|
||||
<arg_key>y</arg_key>
|
||||
<arg_key>z</arg_key>
|
||||
</tool_call>`,
|
||||
wantKeys: []string{"x", "y", "z"},
|
||||
wantValues: []string{"X", "Y", "Z"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var parsed GLMToolCallXML
|
||||
err := xml.Unmarshal([]byte(tc.xml), &parsed)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal XML: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
|
||||
t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
|
||||
t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM46ToolCallParsing(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
rawToolCall string
|
||||
tools []api.Tool
|
||||
wantToolCall api.ToolCall
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "simple tool call",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `get-current-weather
|
||||
<arg_key>location</arg_key>
|
||||
<arg_value>New York, NY</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-current-weather",
|
||||
Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with typed parameters",
|
||||
tools: []api.Tool{
|
||||
tool("calculate", map[string]api.ToolProperty{
|
||||
"x": {Type: api.PropertyType{"number"}},
|
||||
"y": {Type: api.PropertyType{"integer"}},
|
||||
"enabled": {Type: api.PropertyType{"boolean"}},
|
||||
"items": {Type: api.PropertyType{"array"}},
|
||||
}),
|
||||
},
|
||||
rawToolCall: `calculate
|
||||
<arg_key>x</arg_key>
|
||||
<arg_value>3.14</arg_value>
|
||||
<arg_key>y</arg_key>
|
||||
<arg_value>42</arg_value>
|
||||
<arg_key>enabled</arg_key>
|
||||
<arg_value>true</arg_value>
|
||||
<arg_key>items</arg_key>
|
||||
<arg_value>["a", "b", "c"]</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "function name with whitespace",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: ` get-weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Paris</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "Paris"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "values with special characters",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `execute-command
|
||||
<arg_key>command</arg_key>
|
||||
<arg_value>ls && echo "done"</arg_value>
|
||||
<arg_key>message</arg_key>
|
||||
<arg_value>a < b and c > d</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute-command",
|
||||
Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unicode in function names and values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `获取天气
|
||||
<arg_key>城市</arg_key>
|
||||
<arg_value>北京</arg_value>
|
||||
<arg_key>message</arg_key>
|
||||
<arg_value>Hello! 你好! 🌟</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty value",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>param1</arg_key>
|
||||
<arg_value></arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"param1": ""}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "special chars in arg_key names",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>param<1></arg_key>
|
||||
<arg_value>value1</arg_value>
|
||||
<arg_key>a&b</arg_key>
|
||||
<arg_value>value2</arg_value>
|
||||
<arg_key>x>y</arg_key>
|
||||
<arg_value>value3</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple consecutive ampersands",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>param</arg_key>
|
||||
<arg_value>test &&&& more</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"param": "test &&&& more"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed special chars together",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>param</arg_key>
|
||||
<arg_value><>&<>&</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"param": "<>&<>&"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "newlines and tabs in parameter values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>multiline</arg_key>
|
||||
<arg_value>line1
|
||||
indented line2
|
||||
line3</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single and double quotes in values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>quotes</arg_key>
|
||||
<arg_value>She said "Hello's there!"</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CDATA-like content that should be treated as text",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>cdata</arg_key>
|
||||
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all special XML entities",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>entities</arg_key>
|
||||
<arg_value><>&'"</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"entities": "<>&'""}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "order preservation with multiple parameters",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>first</arg_key>
|
||||
<arg_value>value1</arg_value>
|
||||
<arg_key>second</arg_key>
|
||||
<arg_value>value2</arg_value>
|
||||
<arg_key>third</arg_key>
|
||||
<arg_value>value3</arg_value>
|
||||
<arg_key>fourth</arg_key>
|
||||
<arg_value>value4</arg_value>
|
||||
<arg_key>fifth</arg_key>
|
||||
<arg_value>value5</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "order preservation with identical key names but different positions",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `test-function
|
||||
<arg_key>param</arg_key>
|
||||
<arg_value>first occurrence</arg_value>
|
||||
<arg_key>other</arg_key>
|
||||
<arg_value>middle</arg_value>
|
||||
<arg_key>param</arg_key>
|
||||
<arg_value>second occurrence</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test-function",
|
||||
// Later occurrence should overwrite earlier one
|
||||
Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array with mixed types",
|
||||
tools: []api.Tool{
|
||||
tool("process", map[string]api.ToolProperty{
|
||||
"items": {Type: api.PropertyType{"array"}},
|
||||
}),
|
||||
},
|
||||
rawToolCall: `process
|
||||
<arg_key>items</arg_key>
|
||||
<arg_value>[1, "hello", true, null]</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: args(`{"items": [1, "hello", true, null]}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
tools: []api.Tool{
|
||||
tool("test", map[string]api.ToolProperty{
|
||||
"tags": {Type: api.PropertyType{"array"}},
|
||||
}),
|
||||
},
|
||||
rawToolCall: `test
|
||||
<arg_key>tags</arg_key>
|
||||
<arg_value>[]</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: args(`{"tags": []}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "anyOf array or string - with array of objects",
|
||||
tools: []api.Tool{
|
||||
tool("TodoWrite", map[string]api.ToolProperty{
|
||||
"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
|
||||
}),
|
||||
},
|
||||
// <tool_call>TodoWrite
|
||||
// <arg_key>todos</arg_key>
|
||||
// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
|
||||
// </tool_call>
|
||||
rawToolCall: `TodoWrite
|
||||
<arg_key>todos</arg_key>
|
||||
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "TodoWrite",
|
||||
Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "anyOf array or string - with plain string",
|
||||
tools: []api.Tool{
|
||||
tool("TodoWrite", map[string]api.ToolProperty{
|
||||
"todos": {Type: api.PropertyType{"array", "string"}},
|
||||
}),
|
||||
},
|
||||
rawToolCall: `TodoWrite
|
||||
<arg_key>todos</arg_key>
|
||||
<arg_value>Error: could not load todos</arg_value>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "TodoWrite",
|
||||
Arguments: args(`{"todos": "Error: could not load todos"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
|
||||
if err != nil {
|
||||
t.Errorf("case %d (%s): %v", i, tc.name, err)
|
||||
}
|
||||
if !toolCallEqual(gotToolCall, tc.wantToolCall) {
|
||||
t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// GLM47Parser extends GLM46Parser with thinking-aware initialization.
|
||||
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
|
||||
// must start in CollectingThinking state (the model outputs thinking content directly).
|
||||
type GLM47Parser struct {
|
||||
GLM46Parser
|
||||
}
|
||||
|
||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
p.state = glm46ParserState_CollectingThinking
|
||||
}
|
||||
return tools
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM47ParserAdd(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init([]api.Tool{
|
||||
tool("calculate", map[string]api.ToolProperty{
|
||||
"count": {Type: api.PropertyType{"integer"}},
|
||||
"enabled": {Type: api.PropertyType{"boolean"}},
|
||||
}),
|
||||
}, nil, nil)
|
||||
|
||||
// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
|
||||
// so the model output does NOT include the opening <think> tag.
|
||||
content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if thinking != "plan" {
|
||||
t.Fatalf("expected thinking 'plan', got %q", thinking)
|
||||
}
|
||||
if content != "Answer" {
|
||||
t.Fatalf("expected content 'Answer', got %q", content)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
expectedArgs := args(`{"count": 3, "enabled": true}`)
|
||||
if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
|
||||
t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserNoThinkingContent(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
// When thinking is enabled but model has no thinking to output,
|
||||
// it should output </think> immediately followed by content.
|
||||
content, thinking, calls, err := parser.Add("</think>Plain answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Plain answer" {
|
||||
t.Fatalf("expected content 'Plain answer', got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserThinkingDisabled(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
// When thinking is disabled, parser stays in LookingForThinkingOpen state
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
// Model outputs plain content (prompt ended with </think>)
|
||||
content, thinking, calls, err := parser.Add("Plain answer", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||
}
|
||||
if content != "Plain answer" {
|
||||
t.Fatalf("expected content 'Plain answer', got %q", content)
|
||||
}
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
||||
toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
|
||||
<arg_key>expr</arg_key>
|
||||
<arg_value>a < b && c > d</arg_value>`}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
expected := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: args(`{"expr": "a < b && c > d"}`),
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(toolCall, expected) {
|
||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||
}
|
||||
}
|
||||
@@ -68,8 +68,6 @@ func ParserForName(name string) Parser {
|
||||
return &Nemotron3NanoParser{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaParser{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Parser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,11 +96,3 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func args(s string) api.ToolCallFunctionArguments {
|
||||
var result api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
panic("invalid JSON in args(): " + err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GLM46Renderer struct{}
|
||||
|
||||
func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
var lastUserIndex int
|
||||
for i, message := range messages {
|
||||
if message.Role == "user" {
|
||||
lastUserIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(string(d) + "\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}\n")
|
||||
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
|
||||
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
|
||||
sb.WriteString("...\n")
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if i > lastUserIndex {
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("\n<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("\n<think></think>")
|
||||
}
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n" + message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
|
||||
for key, value := range toolCall.Function.Arguments.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
|
||||
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
|
||||
}
|
||||
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt
|
||||
sb.WriteString("<|assistant|>")
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
sb.WriteString("\n<think></think>\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM46Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
skip string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "basic with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "basic with user assistant user",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What is the capital of France?"},
|
||||
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
|
||||
{Role: "user", Content: "Fantastic!"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
The capital of France is Paris.<|user|>
|
||||
Fantastic!<|assistant|>`,
|
||||
},
|
||||
{
|
||||
skip: "tool call ordering not guaranteed yet",
|
||||
name: "tools",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in Tokyo?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}
|
||||
<arg_key>{arg-key-1}</arg_key>
|
||||
<arg_value>{arg-value-1}</arg_value>
|
||||
<arg_key>{arg-key-2}</arg_key>
|
||||
<arg_value>{arg-value-2}</arg_value>
|
||||
...
|
||||
</tool_call><|system|>
|
||||
You are a helpful assistant with access to tools.<|user|>
|
||||
What is the weather like in Tokyo?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
skip: "tool call ordering not guaranteed yet",
|
||||
name: "tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in Tokyo?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
|
||||
ToolName: "get_weather",
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
|
||||
ToolName: "get_weather",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
|
||||
},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}
|
||||
<arg_key>{arg-key-1}</arg_key>
|
||||
<arg_value>{arg-value-1}</arg_value>
|
||||
<arg_key>{arg-key-2}</arg_key>
|
||||
<arg_value>{arg-value-2}</arg_value>
|
||||
...
|
||||
</tool_call><|system|>
|
||||
You are a helpful assistant with access to tools.<|user|>
|
||||
What is the weather like in Tokyo?<|assistant|>
|
||||
<think></think>
|
||||
<tool_call>get_weather
|
||||
<arg_key>location</arg_key>
|
||||
<arg_value>Tokyo, Japan</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_weather
|
||||
<arg_key>location</arg_key>
|
||||
<arg_value>Japan</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call><|observation|>
|
||||
<tool_response>
|
||||
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"temperature": 68, "weather": "sunny", "humidity": 75}
|
||||
</tool_response><|assistant|>
|
||||
<think></think>
|
||||
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "think true",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "think false",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?/nothink<|assistant|>
|
||||
<think></think>
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.skip != "" {
|
||||
t.Skip(tt.skip)
|
||||
}
|
||||
renderer := &GLM46Renderer{}
|
||||
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
t.Logf("Got:\n%s", rendered)
|
||||
t.Logf("Expected:\n%s", tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// GLM47Renderer renders messages for GLM-4.7 models.
|
||||
//
|
||||
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
|
||||
//
|
||||
// 1. INTERLEAVED THINKING
|
||||
// The model thinks between tool calls and after receiving tool results.
|
||||
// This enables complex step-by-step reasoning: interpreting each tool output
|
||||
// before deciding what to do next. Thinking blocks are preserved and returned
|
||||
// with tool results to maintain reasoning continuity.
|
||||
//
|
||||
// 2. PRESERVED THINKING
|
||||
// The model retains reasoning content from previous assistant turns in context.
|
||||
// This preserves reasoning continuity across multi-turn conversations. The
|
||||
// upstream API has a "clear_thinking" parameter to control this:
|
||||
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
|
||||
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
|
||||
//
|
||||
// 3. TURN-LEVEL THINKING
|
||||
// Controls whether the model should reason on each turn. The upstream API
|
||||
// uses "enable_thinking" parameter:
|
||||
// - enable_thinking=true: outputs <think> to start reasoning
|
||||
// - enable_thinking=false: outputs </think> to skip reasoning
|
||||
//
|
||||
// OLLAMA DEFAULTS:
|
||||
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
|
||||
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
|
||||
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
|
||||
// - Users can disable thinking per-turn via thinkValue=false
|
||||
type GLM47Renderer struct{}
|
||||
|
||||
func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(formatGLM47ToolJSON(d))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
|
||||
}
|
||||
|
||||
think := true
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
think = false
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>")
|
||||
sb.WriteString(message.Content)
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("<tool_call>" + toolCall.Function.Name)
|
||||
sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("<tool_response>")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|assistant|>")
|
||||
if think {
|
||||
sb.WriteString("<think>")
|
||||
} else {
|
||||
sb.WriteString("</think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatGLM47ToolJSON copies raw to a string, inserting a single space after
// every ':' and ',' that appears outside a JSON string literal.
//
// Bytes inside double-quoted strings are copied unchanged, and a backslash
// escapes the byte that follows it, so an escaped quote does not end the
// string.
func formatGLM47ToolJSON(raw []byte) string {
	var out strings.Builder
	// Reserve a little extra room for the inserted spaces.
	out.Grow(len(raw) + len(raw)/10)

	var inString, escaped bool
	for _, b := range raw {
		out.WriteByte(b)

		switch {
		case !inString:
			switch b {
			case '"':
				inString = true
			case ':', ',':
				out.WriteByte(' ')
			}
		case escaped:
			// The byte after a backslash is consumed without interpretation.
			escaped = false
		case b == '\\':
			escaped = true
		case b == '"':
			inString = false
		}
	}

	return out.String()
}
|
||||
@@ -1,191 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM47Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic user message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: "[gMASK]<sop><|user|>Hello<|assistant|></think>",
|
||||
},
|
||||
{
|
||||
name: "system and user",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello there"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "assistant with reasoning_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Answer with reasoning."},
|
||||
{Role: "assistant", Thinking: "Plan.", Content: "Done."},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "tool call with empty content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature":22}`},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: propsMap(`{"location": {"type": "string"}}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "tool call with content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Tokyo"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature":22}`},
|
||||
{Role: "assistant", Content: "It is 22C."},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: propsMap(`{"location": {"type": "string"}}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls and responses",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Compare weather"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Tokyo"}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args(`{"location": "Paris"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature":22}`},
|
||||
{Role: "tool", Content: `{"temperature":18}`},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: propsMap(`{"location": {"type": "string"}}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
|
||||
},
|
||||
{
|
||||
name: "preserved thinking in multi-turn",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Think step by step"},
|
||||
{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
|
||||
{Role: "user", Content: "Continue"},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &GLM47Renderer{}
|
||||
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
t.Logf("Got:\n%s", rendered)
|
||||
t.Logf("Expected:\n%s", tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -80,8 +80,6 @@ func rendererForName(name string) Renderer {
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,26 +1,6 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func args(s string) api.ToolCallFunctionArguments {
|
||||
var result api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
panic("invalid JSON in args(): " + err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func propsMap(s string) *api.ToolPropertiesMap {
|
||||
var result api.ToolPropertiesMap
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
panic("invalid JSON in propsMap(): " + err.Error())
|
||||
}
|
||||
return &result
|
||||
}
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
|
||||
@@ -220,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle image generation models
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
s.handleImageGenerate(c, req, name.String(), checkpointStart)
|
||||
return
|
||||
}
|
||||
|
||||
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
|
||||
return
|
||||
@@ -315,7 +321,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner if unload is requested (empty prompt, keep alive is 0)
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
@@ -329,12 +335,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle image generation models
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
s.handleImageGenerate(c, req, name.String(), checkpointStart)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||
return
|
||||
|
||||
@@ -2101,95 +2101,3 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateUnload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var loadFnCalled bool
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mockRunner{}),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
loadFnCalled = true
|
||||
req.successCh <- &runnerRef{llama: &mockRunner{}}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
|
||||
loadFnCalled = false
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "",
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if resp.DoneReason != "unload" {
|
||||
t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Error("expected done to be true")
|
||||
}
|
||||
|
||||
if loadFnCalled {
|
||||
t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -571,7 +571,6 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -281,19 +280,3 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// ExtractModelfileConfig extracts template, system, and license from a parsed Modelfile.
|
||||
func ExtractModelfileConfig(modelfile *parser.Modelfile) *ModelfileConfig {
|
||||
mfConfig := &ModelfileConfig{}
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
return mfConfig
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
// Lazy init MLX when needed for quantization
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("MLX initialization failed: %w", err)
|
||||
}
|
||||
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||
@@ -54,9 +59,6 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp4":
|
||||
// affine mode: group_size=32, bits=4
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
|
||||
@@ -20,10 +20,10 @@ import (
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp4", "fp8":
|
||||
case "", "fp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
|
||||
@@ -7,17 +7,12 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
@@ -51,8 +46,8 @@ func main() {
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
|
||||
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
@@ -66,7 +61,6 @@ func main() {
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
@@ -128,44 +122,6 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *flux2Flag:
|
||||
m := &flux2.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// Load input images with EXIF orientation correction
|
||||
var loadedImages []image.Image
|
||||
for _, path := range inputImages {
|
||||
img, loadErr := loadImageWithEXIF(path)
|
||||
if loadErr != nil {
|
||||
log.Fatalf("Failed to load image %s: %v", path, loadErr)
|
||||
}
|
||||
loadedImages = append(loadedImages, img)
|
||||
}
|
||||
// When input images provided and user didn't override dimensions, use 0 to match input
|
||||
fluxWidth := int32(*width)
|
||||
fluxHeight := int32(*height)
|
||||
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
|
||||
// Both unset, will auto-detect from input
|
||||
} else if len(loadedImages) > 0 && *width == 0 {
|
||||
fluxWidth = 0 // Compute from height + aspect ratio
|
||||
} else if len(loadedImages) > 0 && *height == 0 {
|
||||
fluxHeight = 0 // Compute from width + aspect ratio
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: fluxWidth,
|
||||
Height: fluxHeight,
|
||||
Steps: *steps,
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
InputImages: loadedImages,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
@@ -320,8 +276,6 @@ func detectModelKind(modelPath string) (string, error) {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
case "Flux2KleinPipeline":
|
||||
return "flux2", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
@@ -342,12 +296,3 @@ func detectModelKind(modelPath string) (string, error) {
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
|
||||
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
|
||||
func loadImageWithEXIF(path string) (image.Image, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
return imagegen.DecodeImage(data)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -109,160 +108,3 @@ func clampF(v, min, max float32) float32 {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// DecodeImage decodes image bytes with EXIF orientation applied.
|
||||
func DecodeImage(data []byte) (image.Image, error) {
|
||||
orientation := readJPEGOrientation(data)
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return applyOrientation(img, orientation), nil
|
||||
}
|
||||
|
||||
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
||||
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
||||
func readJPEGOrientation(data []byte) int {
|
||||
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
|
||||
return 1 // Not JPEG
|
||||
}
|
||||
|
||||
r := bytes.NewReader(data[2:])
|
||||
for {
|
||||
var marker [2]byte
|
||||
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
|
||||
return 1
|
||||
}
|
||||
|
||||
if marker[1] == 0xE1 { // APP1 (EXIF)
|
||||
var lenBytes [2]byte
|
||||
if _, err := r.Read(lenBytes[:]); err != nil {
|
||||
return 1
|
||||
}
|
||||
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
||||
if segLen < 14 {
|
||||
r.Seek(int64(segLen), 1)
|
||||
continue
|
||||
}
|
||||
seg := make([]byte, segLen)
|
||||
if _, err := r.Read(seg); err != nil {
|
||||
return 1
|
||||
}
|
||||
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
|
||||
return parseTIFFOrientation(seg[6:])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if marker[1] == 0xD9 || marker[1] == 0xDA {
|
||||
return 1 // EOI or SOS
|
||||
}
|
||||
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
|
||||
continue // RST markers
|
||||
}
|
||||
|
||||
var lenBytes [2]byte
|
||||
if _, err := r.Read(lenBytes[:]); err != nil {
|
||||
return 1
|
||||
}
|
||||
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
||||
if segLen > 0 {
|
||||
r.Seek(int64(segLen), 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseTIFFOrientation(tiff []byte) int {
|
||||
if len(tiff) < 8 {
|
||||
return 1
|
||||
}
|
||||
|
||||
var big bool
|
||||
switch string(tiff[:2]) {
|
||||
case "MM":
|
||||
big = true
|
||||
case "II":
|
||||
big = false
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
|
||||
u16 := func(b []byte) uint16 {
|
||||
if big {
|
||||
return uint16(b[0])<<8 | uint16(b[1])
|
||||
}
|
||||
return uint16(b[1])<<8 | uint16(b[0])
|
||||
}
|
||||
u32 := func(b []byte) uint32 {
|
||||
if big {
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
|
||||
}
|
||||
|
||||
if u16(tiff[2:4]) != 42 {
|
||||
return 1
|
||||
}
|
||||
|
||||
ifdOffset := u32(tiff[4:8])
|
||||
if int(ifdOffset)+2 > len(tiff) {
|
||||
return 1
|
||||
}
|
||||
|
||||
numEntries := u16(tiff[ifdOffset : ifdOffset+2])
|
||||
for i := range int(numEntries) {
|
||||
offset := ifdOffset + 2 + uint32(i)*12
|
||||
if int(offset)+12 > len(tiff) {
|
||||
break
|
||||
}
|
||||
if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
|
||||
o := int(u16(tiff[offset+8 : offset+10]))
|
||||
if o >= 1 && o <= 8 {
|
||||
return o
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func applyOrientation(img image.Image, orientation int) image.Image {
|
||||
if orientation <= 1 || orientation > 8 {
|
||||
return img
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
outW, outH := w, h
|
||||
if orientation >= 5 {
|
||||
outW, outH = h, w
|
||||
}
|
||||
|
||||
out := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
var dx, dy int
|
||||
switch orientation {
|
||||
case 2:
|
||||
dx, dy = w-1-x, y
|
||||
case 3:
|
||||
dx, dy = w-1-x, h-1-y
|
||||
case 4:
|
||||
dx, dy = x, h-1-y
|
||||
case 5:
|
||||
dx, dy = y, x
|
||||
case 6:
|
||||
dx, dy = h-1-y, x
|
||||
case 7:
|
||||
dx, dy = h-1-y, w-1-x
|
||||
case 8:
|
||||
dx, dy = y, w-1-x
|
||||
}
|
||||
out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -24,8 +24,9 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
|
||||
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
@@ -71,38 +72,26 @@ func ResolveModelName(modelName string) string {
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
className := DetectModelType(modelName)
|
||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// DetectModelType reads model_index.json and returns the model type.
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
func DetectModelType(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return ""
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return ""
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// Parse just the class name
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ClassName string `json:"_class_name"`
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return ""
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
|
||||
if index.Architecture != "" {
|
||||
return index.Architecture
|
||||
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
|
||||
return estimate
|
||||
}
|
||||
return index.ClassName
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
@@ -72,8 +72,9 @@ func TestCheckMemoryRequirements(t *testing.T) {
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 20 * GB,
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 21 * GB,
|
||||
"QwenImagePipeline": 80 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
|
||||
@@ -1137,27 +1137,6 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
|
||||
return RMSNorm(x, ones, eps)
|
||||
}
|
||||
|
||||
// LayerNorm applies layer normalization without learnable params
|
||||
// (x - mean) / sqrt(var + eps)
|
||||
func LayerNorm(x *Array, eps float32) *Array {
|
||||
return LayerNormWithWeightBias(x, nil, nil, eps)
|
||||
}
|
||||
|
||||
// LayerNormWithWeightBias computes layer normalization using mlx.fast
|
||||
// weight and bias can be nil for elementwise_affine=False
|
||||
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
|
||||
res := C.mlx_array_new()
|
||||
var wc, bc C.mlx_array
|
||||
if weight != nil {
|
||||
wc = weight.c
|
||||
}
|
||||
if bias != nil {
|
||||
bc = bias.c
|
||||
}
|
||||
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
// RoPE applies rotary position embeddings using mlx.fast
|
||||
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
res := C.mlx_array_new()
|
||||
|
||||
@@ -1,539 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
|
||||
// Klein is a 4B parameter distilled model that supports sub-second inference.
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 4 for Klein)
|
||||
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
InputImages []image.Image // Reference images for image conditioning (already loaded)
|
||||
}
|
||||
|
||||
// Model represents a FLUX.2 Klein model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *qwen3.TextEncoder
|
||||
Transformer *Flux2Transformer2DModel
|
||||
VAE *AutoencoderKLFlux2
|
||||
SchedulerConfig *SchedulerConfig
|
||||
}
|
||||
|
||||
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
|
||||
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
|
||||
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
|
||||
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
|
||||
var TextEncoderLayerIndices = []int{8, 17, 26}
|
||||
|
||||
// Load loads the FLUX.2 Klein model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{}
|
||||
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
|
||||
tokConfig.SpecialTokensMapJSON = data
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &qwen3.TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Flux2Transformer2DModel{}
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
|
||||
// Load VAE
|
||||
m.VAE = &AutoencoderKLFlux2{}
|
||||
if err := m.VAE.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
|
||||
// Evaluate all weights in a single batch (reduces GPU sync overhead)
|
||||
fmt.Print(" Evaluating weights... ")
|
||||
allWeights := mlx.Collect(m.TextEncoder)
|
||||
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
|
||||
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
|
||||
mlx.Eval(allWeights...)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load scheduler config
|
||||
m.SchedulerConfig = DefaultSchedulerConfig()
|
||||
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
|
||||
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
|
||||
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements runner.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
|
||||
const MaxRefPixels = 728 * 728
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Enable MLX compilation for fused kernels
|
||||
mlx.EnableCompile()
|
||||
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 4 // Klein default: 4 steps for distilled model
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
|
||||
}
|
||||
|
||||
// Determine output dimensions
|
||||
if len(cfg.InputImages) > 0 {
|
||||
// With input images, compute missing dimension from aspect ratio
|
||||
// Images are already EXIF-rotated by the caller
|
||||
bounds := cfg.InputImages[0].Bounds()
|
||||
imgW, imgH := bounds.Dx(), bounds.Dy()
|
||||
aspectRatio := float64(imgH) / float64(imgW)
|
||||
if cfg.Width > 0 && cfg.Height <= 0 {
|
||||
// Width specified, compute height
|
||||
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
|
||||
} else if cfg.Height > 0 && cfg.Width <= 0 {
|
||||
// Height specified, compute width
|
||||
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
|
||||
} else if cfg.Width <= 0 && cfg.Height <= 0 {
|
||||
// Neither specified, use input dimensions
|
||||
cfg.Width = int32(imgW)
|
||||
cfg.Height = int32(imgH)
|
||||
}
|
||||
}
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
|
||||
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
|
||||
pixels := int(cfg.Width) * int(cfg.Height)
|
||||
if pixels > MaxOutputPixels {
|
||||
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
|
||||
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
|
||||
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
|
||||
}
|
||||
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
|
||||
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
|
||||
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
|
||||
|
||||
tcfg := m.Transformer.TransformerConfig
|
||||
patchSize := m.VAE.Config.PatchSize
|
||||
|
||||
// Latent dimensions: image / 8 (VAE downscale) / patch_size
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
patchH := latentH / patchSize[0]
|
||||
patchW := latentW / patchSize[1]
|
||||
imgSeqLen := patchH * patchW
|
||||
|
||||
// Text encoding with multi-layer extraction (no padding, use true sequence length)
|
||||
fmt.Print(" Encoding prompt... ")
|
||||
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Encode reference images if provided
|
||||
var refTokens *ImageCondTokens
|
||||
var refHeights, refWidths []int32
|
||||
if len(cfg.InputImages) > 0 {
|
||||
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
|
||||
|
||||
var err error
|
||||
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode reference images: %w", err)
|
||||
}
|
||||
|
||||
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
|
||||
limitPixels := MaxRefPixels
|
||||
if len(cfg.InputImages) > 1 {
|
||||
limitPixels = MaxRefPixels / 2
|
||||
}
|
||||
for _, img := range cfg.InputImages {
|
||||
_, w, h := PrepareImage(img, limitPixels)
|
||||
refHeights = append(refHeights, int32(h/16))
|
||||
refWidths = append(refWidths, int32(w/16))
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
|
||||
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
|
||||
|
||||
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
|
||||
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
|
||||
latentChannels := m.VAE.Config.LatentChannels
|
||||
packedChannels := latentChannels * 4 // 32 * 4 = 128
|
||||
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
|
||||
|
||||
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
|
||||
// This matches diffusers' _pack_latents
|
||||
patches := packLatents(latents)
|
||||
noiseSeqLen := patches.Shape()[1]
|
||||
|
||||
// RoPE cache - includes reference images if present
|
||||
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
|
||||
|
||||
// Cleanup setup arrays when done
|
||||
defer func() {
|
||||
rope.Cos.Free()
|
||||
rope.Sin.Free()
|
||||
promptEmbeds.Free()
|
||||
if refTokens != nil {
|
||||
refTokens.Tokens.Free()
|
||||
}
|
||||
}()
|
||||
|
||||
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
|
||||
timesteps := make([]*mlx.Array, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
|
||||
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
|
||||
}
|
||||
|
||||
// Evaluate setup arrays
|
||||
fmt.Print(" Evaluating setup... ")
|
||||
setupStart := time.Now()
|
||||
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
|
||||
toEval = append(toEval, timesteps...)
|
||||
if refTokens != nil {
|
||||
toEval = append(toEval, refTokens.Tokens)
|
||||
}
|
||||
mlx.Eval(toEval...)
|
||||
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
|
||||
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(0, cfg.Steps)
|
||||
}
|
||||
|
||||
loopStart := time.Now()
|
||||
stepStart := time.Now()
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStartCapture(cfg.CapturePath)
|
||||
}
|
||||
|
||||
timestep := timesteps[i]
|
||||
|
||||
// Prepare input - concatenate noise patches with reference tokens if present
|
||||
imgInput := patches
|
||||
if refTokens != nil {
|
||||
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
|
||||
}
|
||||
|
||||
// Transformer forward pass
|
||||
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
|
||||
|
||||
// If we concatenated reference tokens, slice to only get noise portion
|
||||
if refTokens != nil {
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
// Scheduler step (keep reference to old patches for the computation graph)
|
||||
newPatches := scheduler.Step(output, patches, i)
|
||||
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStopCapture()
|
||||
}
|
||||
|
||||
mlx.Eval(newPatches)
|
||||
patches = newPatches
|
||||
|
||||
elapsed := time.Since(stepStart).Seconds()
|
||||
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
if i == 0 {
|
||||
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
|
||||
} else {
|
||||
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
|
||||
}
|
||||
stepStart = time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
}
|
||||
|
||||
loopTime := time.Since(loopStart).Seconds()
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
|
||||
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
|
||||
|
||||
// Free timesteps now that denoising is done
|
||||
for _, ts := range timesteps {
|
||||
ts.Free()
|
||||
}
|
||||
|
||||
// VAE decode with tiling for larger images
|
||||
fmt.Print(" Decoding VAE... ")
|
||||
vaeStart := time.Now()
|
||||
// Enable tiling for images > 512x512 (latent > 64x64)
|
||||
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
|
||||
if patchH*2 > 64 || patchW*2 > 64 {
|
||||
m.VAE.Tiling = DefaultTilingConfig()
|
||||
}
|
||||
decoded := m.VAE.Decode(patches, patchH, patchW)
|
||||
mlx.Eval(decoded)
|
||||
|
||||
// Free patches now that decode is done
|
||||
patches.Free()
|
||||
|
||||
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
|
||||
func packLatents(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
|
||||
x = mlx.Reshape(x, B, C, H*W)
|
||||
return mlx.Transpose(x, 0, 2, 1)
|
||||
}
|
||||
|
||||
// LoadPersistent loads the model and keeps it in memory for repeated use.
|
||||
func LoadPersistent(modelName string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ImageRefScale is the spacing, along the time (T) position coordinate,
// between consecutive reference images. The original comment notes it matches
// the diffusers value scale=10.
const ImageRefScale = 10
|
||||
|
||||
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
|
||||
// Returns the processed image and its dimensions.
|
||||
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
// Cap pixels if needed (like diffusers cap_pixels)
|
||||
if limitPixels > 0 && w*h > limitPixels {
|
||||
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
|
||||
w = int(float64(w) * scale)
|
||||
h = int(float64(h) * scale)
|
||||
}
|
||||
|
||||
// Round down to multiple of 16
|
||||
w = (w / 16) * 16
|
||||
h = (h / 16) * 16
|
||||
|
||||
if w < 16 {
|
||||
w = 16
|
||||
}
|
||||
if h < 16 {
|
||||
h = 16
|
||||
}
|
||||
|
||||
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
|
||||
resized := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
return resized, w, h
|
||||
}
|
||||
|
||||
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
|
||||
func ImageToTensor(img image.Image) *mlx.Array {
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
|
||||
data := make([]float32, 3*h*w)
|
||||
|
||||
for y := 0; y < h; y++ {
|
||||
for x := 0; x < w; x++ {
|
||||
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
|
||||
// RGBA returns 16-bit values, convert to [-1, 1]
|
||||
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
|
||||
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
|
||||
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
|
||||
return arr
|
||||
}
|
||||
|
||||
// ImageCondTokens holds encoded reference image tokens.
|
||||
type ImageCondTokens struct {
|
||||
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
|
||||
}
|
||||
|
||||
// EncodeImageRefs encodes reference images using the VAE.
|
||||
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
|
||||
if len(images) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Limit reference images to reduce attention memory
|
||||
limitPixels := MaxRefPixels
|
||||
if len(images) > 1 {
|
||||
limitPixels = MaxRefPixels / 2
|
||||
}
|
||||
|
||||
var allTokens []*mlx.Array
|
||||
|
||||
for _, img := range images {
|
||||
// Prepare image (resize, crop to multiple of 16)
|
||||
prepared, prepW, prepH := PrepareImage(img, limitPixels)
|
||||
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
|
||||
|
||||
// Convert to tensor [-1, 1]
|
||||
tensor := ImageToTensor(prepared)
|
||||
|
||||
// Encode with VAE - returns [1, L, 128]
|
||||
encoded := m.VAE.EncodeImage(tensor)
|
||||
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
|
||||
|
||||
// Defer eval - will be done with other setup arrays
|
||||
allTokens = append(allTokens, squeezed)
|
||||
fmt.Println("✓")
|
||||
}
|
||||
|
||||
// For single image, just add batch dimension directly
|
||||
// For multiple images, concatenate first
|
||||
var tokens *mlx.Array
|
||||
if len(allTokens) == 1 {
|
||||
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
|
||||
} else {
|
||||
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
|
||||
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
|
||||
}
|
||||
|
||||
return &ImageCondTokens{Tokens: tokens}, nil
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// RoPEConfig describes the 4-axis rotary position embedding used by Flux2.
type RoPEConfig struct {
	// Theta is the RoPE base frequency (the original comment cites 2000 for
	// Klein).
	Theta int32
	// AxesDims gives the number of head dimensions assigned to each of the
	// T, H, W and L axes, e.g. [32, 32, 32, 32].
	AxesDims []int32
}
|
||||
|
||||
// RoPECache holds precomputed RoPE cos/sin values
|
||||
type RoPECache struct {
|
||||
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
||||
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
||||
TextLen int32 // Length of text sequence
|
||||
ImageLen int32 // Length of image sequence
|
||||
}
|
||||
|
||||
// PrepareTextIDs creates position IDs for text tokens.
|
||||
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
|
||||
// Returns: [seqLen, 4]
|
||||
func PrepareTextIDs(seqLen int32) *mlx.Array {
|
||||
ids := make([]float32, seqLen*4)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
idx := i * 4
|
||||
ids[idx+0] = 0 // T = 0
|
||||
ids[idx+1] = 0 // H = 0
|
||||
ids[idx+2] = 0 // W = 0
|
||||
ids[idx+3] = float32(i) // L = sequence position
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{seqLen, 4})
|
||||
}
|
||||
|
||||
// PrepareLatentIDs creates position IDs for image latent tokens.
|
||||
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
|
||||
// The latents are in row-major order (H then W).
|
||||
// Returns: [height*width, 4]
|
||||
func PrepareLatentIDs(height, width int32) *mlx.Array {
|
||||
seqLen := height * width
|
||||
ids := make([]float32, seqLen*4)
|
||||
idx := 0
|
||||
for h := int32(0); h < height; h++ {
|
||||
for w := int32(0); w < width; w++ {
|
||||
ids[idx*4+0] = 0 // T = 0
|
||||
ids[idx*4+1] = float32(h) // H = row
|
||||
ids[idx*4+2] = float32(w) // W = column
|
||||
ids[idx*4+3] = 0 // L = 0
|
||||
idx++
|
||||
}
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{seqLen, 4})
|
||||
}
|
||||
|
||||
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
|
||||
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
|
||||
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
|
||||
// Returns: [total_tokens, 4]
|
||||
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
|
||||
// Calculate total tokens
|
||||
totalTokens := int32(0)
|
||||
for i := range imageHeights {
|
||||
totalTokens += imageHeights[i] * imageWidths[i]
|
||||
}
|
||||
|
||||
ids := make([]float32, totalTokens*4)
|
||||
idx := int32(0)
|
||||
for imgIdx, h := range imageHeights {
|
||||
w := imageWidths[imgIdx]
|
||||
tValue := float32(scale * int32(imgIdx+1))
|
||||
for hi := int32(0); hi < h; hi++ {
|
||||
for wi := int32(0); wi < w; wi++ {
|
||||
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
|
||||
ids[idx*4+1] = float32(hi) // H = row
|
||||
ids[idx*4+2] = float32(wi) // W = column
|
||||
ids[idx*4+3] = 0 // L = 0
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{totalTokens, 4})
|
||||
}
|
||||
|
||||
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
|
||||
// ids: [L, 4] with (T, H, W, L) coordinates
|
||||
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
|
||||
// theta: base frequency (2000 for Klein)
|
||||
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
|
||||
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
|
||||
shape := ids.Shape()
|
||||
seqLen := shape[0]
|
||||
|
||||
// Compute total head dim (sum of all axes dims)
|
||||
headDim := int32(0)
|
||||
for _, d := range axesDims {
|
||||
headDim += d
|
||||
}
|
||||
|
||||
// Extract each coordinate dimension
|
||||
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
|
||||
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
|
||||
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
|
||||
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
|
||||
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
|
||||
|
||||
// Compute frequencies for each axis
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
cosArrs := make([]*mlx.Array, 4)
|
||||
sinArrs := make([]*mlx.Array, 4)
|
||||
positions := []*mlx.Array{posT, posH, posW, posL}
|
||||
|
||||
for i, axisDim := range axesDims {
|
||||
half := axisDim / 2
|
||||
|
||||
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
|
||||
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
|
||||
freqs := make([]float32, half)
|
||||
for j := int32(0); j < half; j++ {
|
||||
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
|
||||
}
|
||||
freqArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// Compute pos * freq -> [L, half]
|
||||
posExpanded := positions[i] // [L, 1]
|
||||
args := mlx.Mul(posExpanded, freqArr) // [L, half]
|
||||
|
||||
// Compute cos and sin for this axis
|
||||
cosAxis := mlx.Cos(args) // [L, half]
|
||||
sinAxis := mlx.Sin(args) // [L, half]
|
||||
|
||||
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
|
||||
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
|
||||
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
|
||||
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
|
||||
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
|
||||
|
||||
sinAxis = mlx.ExpandDims(sinAxis, 2)
|
||||
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
|
||||
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
|
||||
|
||||
cosArrs[i] = cosAxis
|
||||
sinArrs[i] = sinAxis
|
||||
}
|
||||
|
||||
// Concatenate all axes: [L, headDim]
|
||||
cos := mlx.Concatenate(cosArrs, 1)
|
||||
sin := mlx.Concatenate(sinArrs, 1)
|
||||
|
||||
// Reshape to [1, L, 1, headDim] for broadcasting with attention
|
||||
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
|
||||
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
|
||||
|
||||
return cos, sin
|
||||
}
|
||||
|
||||
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
|
||||
// Returns: x with RoPE applied
|
||||
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
|
||||
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
half := headDim / 2
|
||||
|
||||
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
|
||||
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
|
||||
|
||||
// Extract real (index 0) and imag (index 1) parts
|
||||
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
|
||||
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
|
||||
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
|
||||
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
|
||||
|
||||
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
|
||||
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
|
||||
negXImag := mlx.Neg(xImag)
|
||||
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
|
||||
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
|
||||
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
|
||||
}
|
||||
|
||||
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
|
||||
// textLen: number of text tokens
|
||||
// noiseH, noiseW: dimensions of the noise latent in patch tokens
|
||||
// axesDims: [32, 32, 32, 32]
|
||||
// theta: 2000
|
||||
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
|
||||
// scale: time coordinate offset between reference images (e.g., 10)
|
||||
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
|
||||
textIDs := PrepareTextIDs(textLen)
|
||||
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
|
||||
|
||||
var allIDs *mlx.Array
|
||||
imageLen := noiseH * noiseW
|
||||
|
||||
if len(refHeights) > 0 {
|
||||
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
|
||||
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
|
||||
for i := range refHeights {
|
||||
imageLen += refHeights[i] * refWidths[i]
|
||||
}
|
||||
} else {
|
||||
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
|
||||
}
|
||||
|
||||
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
|
||||
cos = mlx.ToBFloat16(cos)
|
||||
sin = mlx.ToBFloat16(sin)
|
||||
|
||||
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig is the JSON-decodable configuration of the Flow-Match
// scheduler.
type SchedulerConfig struct {
	// NumTrainTimesteps scales sigmas into timesteps (e.g. 1000).
	NumTrainTimesteps int32 `json:"num_train_timesteps"`
	// Shift is the fixed shift applied when dynamic shifting is not used
	// (e.g. 3.0 for Klein).
	Shift float32 `json:"shift"`
	// UseDynamicShifting enables the mu-based time shift.
	UseDynamicShifting bool `json:"use_dynamic_shifting"`
	// TimeShiftType selects the dynamic shift formula: "linear", otherwise
	// the exponential form is used.
	TimeShiftType string `json:"time_shift_type"`
}
|
||||
|
||||
// DefaultSchedulerConfig returns default config for Klein
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
Shift: 3.0, // Klein uses 3.0
|
||||
UseDynamicShifting: true,
|
||||
TimeShiftType: "exponential",
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
|
||||
Sigmas []float32 // Noise levels at each timestep
|
||||
NumSteps int // Number of inference steps
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
|
||||
s.SetTimestepsWithMu(numSteps, 0)
|
||||
}
|
||||
|
||||
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
|
||||
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
|
||||
// Then applies time shift, appends 0.0 at end
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
|
||||
for i := 0; i < numSteps; i++ {
|
||||
// linspace(1, 1/num_steps, num_steps)
|
||||
var sigma float32
|
||||
if numSteps == 1 {
|
||||
sigma = 1.0
|
||||
} else {
|
||||
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
|
||||
}
|
||||
|
||||
// Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShifting && mu != 0 {
|
||||
sigma = s.timeShift(mu, sigma)
|
||||
} else {
|
||||
// If not dynamic shifting, apply fixed shift scaling like diffusers
|
||||
shift := s.Config.Shift
|
||||
sigma = shift * sigma / (1 + (shift-1)*sigma)
|
||||
}
|
||||
s.Sigmas[i] = sigma
|
||||
}
|
||||
// Append terminal zero
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
|
||||
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i, v := range s.Sigmas {
|
||||
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
|
||||
}
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
if s.Config.TimeShiftType == "linear" {
|
||||
return mu / (mu + (1.0/t-1.0))
|
||||
}
|
||||
// Default: exponential
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 for precision (matches diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(outputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to bfloat16
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// CalculateShift returns the mu value used for dynamic time shifting; the
// original comment states it matches diffusers' compute_empirical_mu.
//
// Two linear fits in the image sequence length give the value of mu at 200
// steps (m200) and at 10 steps (m10). For sequences longer than 4300 tokens
// m200 is returned directly; otherwise mu is the line through (10, m10) and
// (200, m200) evaluated at numSteps.
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
	a1, b1 := float32(8.73809524e-05), float32(1.89833333)
	a2, b2 := float32(0.00016927), float32(0.45666666)

	seqLen := float32(imgSeqLen)
	m200 := a2*seqLen + b2

	if imgSeqLen > 4300 {
		return m200
	}

	m10 := a1*seqLen + b1

	// Line through (10, m10) and (200, m200).
	a := (m200 - m10) / 190.0
	b := m200 - 200.0*a
	return a*float32(numSteps) + b
}
|
||||
@@ -1,562 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig is the JSON-decodable configuration of the Flux2
// transformer. The example values are those given in the original comments.
type TransformerConfig struct {
	// AttentionHeadDim is the per-head feature size (e.g. 128).
	AttentionHeadDim int32 `json:"attention_head_dim"`
	// AxesDimsRoPE lists the RoPE dimensions per axis (e.g. [32, 32, 32, 32]).
	AxesDimsRoPE []int32 `json:"axes_dims_rope"`
	// Eps is the normalisation epsilon (e.g. 1e-6).
	Eps float32 `json:"eps"`
	// GuidanceEmbeds toggles guidance embeddings (false for Klein).
	GuidanceEmbeds bool `json:"guidance_embeds"`
	// InChannels is the input channel count (e.g. 128).
	InChannels int32 `json:"in_channels"`
	// JointAttentionDim is the text embedding width (e.g. 7680).
	JointAttentionDim int32 `json:"joint_attention_dim"`
	// MLPRatio scales the inner dimension to the MLP hidden size (e.g. 3.0).
	MLPRatio float32 `json:"mlp_ratio"`
	// NumAttentionHeads is the number of attention heads (e.g. 24).
	NumAttentionHeads int32 `json:"num_attention_heads"`
	// NumLayers is the number of dual-stream blocks (e.g. 5).
	NumLayers int32 `json:"num_layers"`
	// NumSingleLayers is the number of single-stream blocks (e.g. 20).
	NumSingleLayers int32 `json:"num_single_layers"`
	// PatchSize is the transformer patch size (e.g. 1).
	PatchSize int32 `json:"patch_size"`
	// RopeTheta is the RoPE base frequency (e.g. 2000).
	RopeTheta int32 `json:"rope_theta"`
	// TimestepGuidanceChannels is the sinusoidal embedding width (e.g. 256).
	TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"`
}

// InnerDim returns the model width: NumAttentionHeads * AttentionHeadDim
// (24 * 128 = 3072 with the example values).
func (c *TransformerConfig) InnerDim() int32 {
	heads, perHead := c.NumAttentionHeads, c.AttentionHeadDim
	return heads * perHead
}

// MLPHiddenDim returns InnerDim scaled by MLPRatio, truncated to an integer
// (3072 * 3.0 = 9216 with the example values).
func (c *TransformerConfig) MLPHiddenDim() int32 {
	inner := float32(c.InnerDim())
	return int32(inner * c.MLPRatio)
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 nn.LinearLayer `weight:"linear_1"`
|
||||
Linear2 nn.LinearLayer `weight:"linear_2"`
|
||||
EmbedDim int32 // 256
|
||||
}
|
||||
|
||||
// Forward creates sinusoidal embeddings and projects them
|
||||
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
|
||||
half := t.EmbedDim / 2
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// timesteps: [B] -> [B, 1]
|
||||
tExpanded := mlx.ExpandDims(timesteps, 1)
|
||||
// args: [B, half]
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
|
||||
// [cos(args), sin(args)] -> [B, embed_dim]
|
||||
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
|
||||
|
||||
// MLP: linear_1 -> silu -> linear_2
|
||||
h := t.Linear1.Forward(sinEmbed)
|
||||
h = mlx.SiLU(h)
|
||||
return t.Linear2.Forward(h)
|
||||
}
|
||||
|
||||
// TimeGuidanceEmbed wraps the timestep embedder
|
||||
// Weight names: time_guidance_embed.timestep_embedder.*
|
||||
type TimeGuidanceEmbed struct {
|
||||
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
|
||||
return t.TimestepEmbedder.Forward(timesteps)
|
||||
}
|
||||
|
||||
// Modulation computes adaptive modulation parameters
|
||||
// Weight names: double_stream_modulation_img.linear.weight, etc.
|
||||
type Modulation struct {
|
||||
Linear nn.LinearLayer `weight:"linear"`
|
||||
}
|
||||
|
||||
// Forward computes modulation parameters
|
||||
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
|
||||
h := mlx.SiLU(temb)
|
||||
return m.Linear.Forward(h)
|
||||
}
|
||||
|
||||
// TransformerBlockAttn implements dual-stream attention
|
||||
// Weight names: transformer_blocks.N.attn.*
|
||||
type TransformerBlockAttn struct {
|
||||
// Image stream (separate Q, K, V projections)
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
// Note: to_out has .0 suffix in weights, handled specially
|
||||
ToOut0 nn.LinearLayer `weight:"to_out.0"`
|
||||
|
||||
// Text stream (add_ projections)
|
||||
AddQProj nn.LinearLayer `weight:"add_q_proj"`
|
||||
AddKProj nn.LinearLayer `weight:"add_k_proj"`
|
||||
AddVProj nn.LinearLayer `weight:"add_v_proj"`
|
||||
ToAddOut nn.LinearLayer `weight:"to_add_out"`
|
||||
|
||||
// QK norms for image stream
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"`
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
|
||||
// QK norms for text stream (added)
|
||||
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
|
||||
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
|
||||
}
|
||||
|
||||
// FeedForward implements SwiGLU MLP
|
||||
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
|
||||
type FeedForward struct {
|
||||
LinearIn nn.LinearLayer `weight:"linear_in"`
|
||||
LinearOut nn.LinearLayer `weight:"linear_out"`
|
||||
}
|
||||
|
||||
// Forward applies SwiGLU MLP
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
// LinearIn outputs 2x hidden dim for SwiGLU
|
||||
h := ff.LinearIn.Forward(x)
|
||||
shape := h.Shape()
|
||||
half := shape[len(shape)-1] / 2
|
||||
|
||||
// Split into gate and up
|
||||
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
|
||||
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
|
||||
|
||||
// SwiGLU: silu(gate) * up
|
||||
h = mlx.Mul(mlx.SiLU(gate), up)
|
||||
return ff.LinearOut.Forward(h)
|
||||
}
|
||||
|
||||
// TransformerBlock implements a dual-stream transformer block
|
||||
// Weight names: transformer_blocks.N.*
|
||||
type TransformerBlock struct {
|
||||
Attn *TransformerBlockAttn `weight:"attn"`
|
||||
FF *FeedForward `weight:"ff"`
|
||||
FFContext *FeedForward `weight:"ff_context"`
|
||||
|
||||
// Config (set after loading)
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward applies the dual-stream block
|
||||
// imgHidden: [B, imgLen, dim]
|
||||
// txtHidden: [B, txtLen, dim]
|
||||
// imgMod, txtMod: modulation params [B, 6*dim] each
|
||||
// cos, sin: RoPE values
|
||||
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := imgHidden.Shape()
|
||||
B := imgShape[0]
|
||||
imgLen := imgShape[1]
|
||||
dim := imgShape[2]
|
||||
txtLen := txtHidden.Shape()[1]
|
||||
|
||||
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
|
||||
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
|
||||
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
|
||||
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
|
||||
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
|
||||
|
||||
// === Attention branch ===
|
||||
// Modulate inputs
|
||||
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
|
||||
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
|
||||
|
||||
// Compute Q, K, V for image stream (separate projections)
|
||||
imgQ := block.Attn.ToQ.Forward(imgNorm)
|
||||
imgK := block.Attn.ToK.Forward(imgNorm)
|
||||
imgV := block.Attn.ToV.Forward(imgNorm)
|
||||
|
||||
// Compute Q, K, V for text stream (add_ projections)
|
||||
txtQ := block.Attn.AddQProj.Forward(txtNorm)
|
||||
txtK := block.Attn.AddKProj.Forward(txtNorm)
|
||||
txtV := block.Attn.AddVProj.Forward(txtNorm)
|
||||
|
||||
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
|
||||
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
|
||||
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
|
||||
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
|
||||
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
|
||||
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
|
||||
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
|
||||
|
||||
// Apply QK norm (RMSNorm with learned scale)
|
||||
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
|
||||
imgK = applyQKNorm(imgK, block.Attn.NormK)
|
||||
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
|
||||
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
|
||||
|
||||
// Concatenate for joint attention: text first, then image
|
||||
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
|
||||
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
|
||||
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
|
||||
|
||||
// Apply RoPE
|
||||
q = ApplyRoPE4D(q, cos, sin)
|
||||
k = ApplyRoPE4D(k, cos, sin)
|
||||
|
||||
// Transpose for SDPA: [B, nheads, L, headDim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
||||
|
||||
// Transpose back: [B, L, nheads, headDim]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
|
||||
// Split back into txt and img
|
||||
totalLen := txtLen + imgLen
|
||||
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
|
||||
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
|
||||
|
||||
// Reshape and project
|
||||
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
|
||||
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
|
||||
txtOut = block.Attn.ToAddOut.Forward(txtOut)
|
||||
imgOut = block.Attn.ToOut0.Forward(imgOut)
|
||||
|
||||
// Apply gates and residual
|
||||
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
|
||||
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
|
||||
|
||||
// === MLP branch ===
|
||||
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
|
||||
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
|
||||
|
||||
imgFFOut := block.FF.Forward(imgNorm)
|
||||
txtFFOut := block.FFContext.Forward(txtNorm)
|
||||
|
||||
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
|
||||
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
|
||||
|
||||
return imgHidden, txtHidden
|
||||
}
|
||||
|
||||
// SingleTransformerBlockAttn implements attention for single-stream blocks
|
||||
// Weight names: single_transformer_blocks.N.attn.*
|
||||
type SingleTransformerBlockAttn struct {
|
||||
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
|
||||
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"`
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
}
|
||||
|
||||
// SingleTransformerBlock implements a single-stream transformer block
|
||||
// Weight names: single_transformer_blocks.N.*
|
||||
type SingleTransformerBlock struct {
|
||||
Attn *SingleTransformerBlockAttn `weight:"attn"`
|
||||
|
||||
// Config
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
InnerDim int32
|
||||
MLPHidDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward applies the single-stream block
|
||||
// x: [B, L, dim] concatenated text+image
|
||||
// mod: modulation [B, 3*dim]
|
||||
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
dim := shape[2]
|
||||
|
||||
// Parse modulation: (shift, scale, gate)
|
||||
shift, scale, gate := parseModulation3(mod, dim, 0)
|
||||
|
||||
// Modulate input
|
||||
h := modulateLayerNorm(x, shift, scale)
|
||||
|
||||
// Fused projection: QKV + MLP gate/up
|
||||
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
|
||||
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
|
||||
|
||||
// Split: first 3*dim is QKV, rest is MLP
|
||||
qkvDim := 3 * block.InnerDim
|
||||
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
|
||||
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
|
||||
|
||||
// Split QKV
|
||||
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
|
||||
|
||||
// Reshape for attention
|
||||
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
|
||||
|
||||
// QK norm
|
||||
q = applyQKNorm(q, block.Attn.NormQ)
|
||||
k = applyQKNorm(k, block.Attn.NormK)
|
||||
|
||||
// Apply RoPE
|
||||
q = ApplyRoPE4D(q, cos, sin)
|
||||
k = ApplyRoPE4D(k, cos, sin)
|
||||
|
||||
// Transpose for SDPA
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
||||
|
||||
// Transpose back and reshape
|
||||
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
|
||||
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
|
||||
|
||||
// MLP: SwiGLU
|
||||
mlpShape := mlpIn.Shape()
|
||||
half := mlpShape[2] / 2
|
||||
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
|
||||
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
|
||||
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
|
||||
|
||||
// Concatenate attention and MLP for fused output
|
||||
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
|
||||
|
||||
// Output projection
|
||||
out := block.Attn.ToOut.Forward(combined)
|
||||
|
||||
// Apply gate and residual
|
||||
return mlx.Add(x, mlx.Mul(gate, out))
|
||||
}
|
||||
|
||||
// NormOut implements the output normalization with modulation
|
||||
// Weight names: norm_out.linear.weight
|
||||
type NormOut struct {
|
||||
Linear nn.LinearLayer `weight:"linear"`
|
||||
}
|
||||
|
||||
// Forward computes final modulated output
|
||||
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
dim := shape[2]
|
||||
|
||||
// Modulation: temb -> silu -> linear -> [shift, scale]
|
||||
mod := mlx.SiLU(temb)
|
||||
mod = n.Linear.Forward(mod)
|
||||
|
||||
// Split into scale and shift (diffusers order: scale first, shift second)
|
||||
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
|
||||
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
|
||||
shift = mlx.ExpandDims(shift, 1)
|
||||
scale = mlx.ExpandDims(scale, 1)
|
||||
|
||||
// Modulate with RMSNorm
|
||||
return modulateLayerNorm(x, shift, scale)
|
||||
}
|
||||
|
||||
// Flux2Transformer2DModel is the main Flux2 transformer
|
||||
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
|
||||
type Flux2Transformer2DModel struct {
|
||||
// Timestep embedding
|
||||
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
|
||||
|
||||
// Shared modulation
|
||||
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
|
||||
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
|
||||
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
|
||||
|
||||
// Embedders
|
||||
XEmbedder nn.LinearLayer `weight:"x_embedder"`
|
||||
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
|
||||
|
||||
// Transformer blocks
|
||||
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
|
||||
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
|
||||
|
||||
// Output
|
||||
NormOut *NormOut `weight:"norm_out"`
|
||||
ProjOut nn.LinearLayer `weight:"proj_out"`
|
||||
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Flux2 transformer from ollama blob storage.
|
||||
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = &cfg
|
||||
|
||||
// Initialize slices
|
||||
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
|
||||
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
|
||||
|
||||
// Initialize TimeGuidanceEmbed with embed dim
|
||||
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
|
||||
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
|
||||
}
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Flux2Transformer2DModel) initComputedFields() {
|
||||
cfg := m.TransformerConfig
|
||||
innerDim := cfg.InnerDim()
|
||||
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
|
||||
|
||||
// Initialize transformer blocks
|
||||
for _, block := range m.TransformerBlocks {
|
||||
block.NHeads = cfg.NumAttentionHeads
|
||||
block.HeadDim = cfg.AttentionHeadDim
|
||||
block.Scale = scale
|
||||
}
|
||||
|
||||
// Initialize single transformer blocks
|
||||
for _, block := range m.SingleTransformerBlocks {
|
||||
block.NHeads = cfg.NumAttentionHeads
|
||||
block.HeadDim = cfg.AttentionHeadDim
|
||||
block.InnerDim = innerDim
|
||||
block.MLPHidDim = cfg.MLPHiddenDim()
|
||||
block.Scale = scale
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs the Flux2 transformer
|
||||
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
|
||||
patchShape := patches.Shape()
|
||||
B := patchShape[0]
|
||||
imgLen := patchShape[1]
|
||||
txtLen := txtEmbeds.Shape()[1]
|
||||
|
||||
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
|
||||
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
|
||||
|
||||
// Compute timestep embedding
|
||||
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
|
||||
|
||||
// Embed patches and text
|
||||
imgHidden := m.XEmbedder.Forward(patches)
|
||||
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
|
||||
|
||||
// Compute shared modulation
|
||||
imgMod := m.DoubleStreamModulationImg.Forward(temb)
|
||||
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
|
||||
singleMod := m.SingleStreamModulation.Forward(temb)
|
||||
|
||||
// Double (dual-stream) blocks
|
||||
for _, block := range m.TransformerBlocks {
|
||||
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
|
||||
}
|
||||
|
||||
// Concatenate for single-stream: text first, then image
|
||||
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
|
||||
|
||||
// Single-stream blocks
|
||||
for _, block := range m.SingleTransformerBlocks {
|
||||
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
|
||||
}
|
||||
|
||||
// Extract image portion
|
||||
totalLen := txtLen + imgLen
|
||||
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
|
||||
|
||||
// Final norm and projection
|
||||
imgOut = m.NormOut.Forward(imgOut, temb)
|
||||
return m.ProjOut.Forward(imgOut)
|
||||
}
|
||||
|
||||
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
|
||||
// See applyQKNorm function below
|
||||
|
||||
// compiledSwiGLU fuses: silu(gate) * up
|
||||
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
gate, up := inputs[0], inputs[1]
|
||||
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
|
||||
}, true)
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
|
||||
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
B := mod.Shape()[0]
|
||||
start := offset * dim
|
||||
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
|
||||
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
|
||||
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
|
||||
|
||||
// Expand for broadcasting [B, dim] -> [B, 1, dim]
|
||||
shift = mlx.ExpandDims(shift, 1)
|
||||
scale = mlx.ExpandDims(scale, 1)
|
||||
gate = mlx.ExpandDims(gate, 1)
|
||||
|
||||
return shift, scale, gate
|
||||
}
|
||||
|
||||
// modulateLayerNorm applies LayerNorm then shift/scale modulation
|
||||
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
|
||||
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
|
||||
// Fast LayerNorm without learnable params
|
||||
x = mlx.LayerNorm(x, 1e-6)
|
||||
|
||||
// Modulate: x * (1 + scale) + shift
|
||||
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
||||
return mlx.Add(x, shift)
|
||||
}
|
||||
|
||||
// splitQKV splits a fused QKV tensor into Q, K, V
|
||||
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
|
||||
return q, k, v
|
||||
}
|
||||
|
||||
// applyQKNorm applies RMSNorm with learned scale (no bias)
|
||||
// Uses the optimized mlx_fast_rms_norm
|
||||
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
|
||||
return mlx.RMSNorm(x, scale, 1e-6)
|
||||
}
|
||||
@@ -1,804 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// VAEConfig mirrors the JSON configuration of AutoencoderKLFlux2.
// The trailing comments give example values, not enforced defaults.
type VAEConfig struct {
	ActFn             string  `json:"act_fn"`                  // e.g. "silu"
	BatchNormEps      float32 `json:"batch_norm_eps"`          // e.g. 0.0001
	BatchNormMomentum float32 `json:"batch_norm_momentum"`     // e.g. 0.1
	BlockOutChannels  []int32 `json:"block_out_channels"`      // e.g. [128, 256, 512, 512]
	ForceUpcast       bool    `json:"force_upcast"`            // e.g. true
	InChannels        int32   `json:"in_channels"`             // e.g. 3
	LatentChannels    int32   `json:"latent_channels"`         // e.g. 32
	LayersPerBlock    int32   `json:"layers_per_block"`        // e.g. 2
	MidBlockAddAttn   bool    `json:"mid_block_add_attention"` // e.g. true
	NormNumGroups     int32   `json:"norm_num_groups"`         // e.g. 32
	OutChannels       int32   `json:"out_channels"`            // e.g. 3
	PatchSize         []int32 `json:"patch_size"`              // e.g. [2, 2]
	SampleSize        int32   `json:"sample_size"`             // e.g. 1024
	UsePostQuantConv  bool    `json:"use_post_quant_conv"`     // e.g. true
	UseQuantConv      bool    `json:"use_quant_conv"`          // e.g. true
}
|
||||
|
||||
// BatchNorm2D implements 2D batch normalization with running statistics
|
||||
type BatchNorm2D struct {
|
||||
RunningMean *mlx.Array // [C]
|
||||
RunningVar *mlx.Array // [C]
|
||||
Weight *mlx.Array // [C] gamma
|
||||
Bias *mlx.Array // [C] beta
|
||||
Eps float32
|
||||
Momentum float32
|
||||
}
|
||||
|
||||
// Forward applies batch normalization (inference mode - uses running stats)
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[3]
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, 1, C]
|
||||
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
|
||||
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
|
||||
|
||||
// Normalize: (x - mean) / sqrt(var + eps)
|
||||
xNorm := mlx.Sub(x, mean)
|
||||
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
|
||||
|
||||
// Scale and shift (only if affine=True)
|
||||
if bn.Weight != nil {
|
||||
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if bn.Bias != nil {
|
||||
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Denormalize inverts the batch normalization
|
||||
// Used when decoding latents
|
||||
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[3]
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, 1, C]
|
||||
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
|
||||
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
|
||||
|
||||
// Inverse: first undo affine, then undo normalization
|
||||
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
|
||||
if bn.Bias != nil {
|
||||
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
|
||||
x = mlx.Sub(x, bias)
|
||||
}
|
||||
if bn.Weight != nil {
|
||||
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
|
||||
x = mlx.Div(x, weight)
|
||||
}
|
||||
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
|
||||
x = mlx.Add(x, mean)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
// Reused from zimage package pattern
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Reshape to [B, H, W, groups, C/groups]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
|
||||
sq := mlx.Square(xCentered)
|
||||
variance := mlx.Mean(sq, 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, H, W, C]
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer (reused pattern)
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias,optional"`
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
|
||||
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
|
||||
if field == "Weight" {
|
||||
return mlx.Transpose(arr, 0, 2, 3, 1)
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// Forward applies convolution (NHWC format)
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// ResnetBlock2D implements a ResNet block for VAE
|
||||
type ResnetBlock2D struct {
|
||||
Norm1 *GroupNormLayer `weight:"norm1"`
|
||||
Conv1 *Conv2D `weight:"conv1"`
|
||||
Norm2 *GroupNormLayer `weight:"norm2"`
|
||||
Conv2 *Conv2D `weight:"conv2"`
|
||||
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"`
|
||||
}
|
||||
|
||||
// Forward applies the ResNet block
|
||||
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := rb.Norm1.Forward(x)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
|
||||
h = rb.Norm2.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
|
||||
if rb.ConvShortcut != nil {
|
||||
x = rb.ConvShortcut.Forward(x)
|
||||
}
|
||||
|
||||
return mlx.Add(h, x)
|
||||
}
|
||||
|
||||
// VAEAttentionBlock implements self-attention for VAE
|
||||
type VAEAttentionBlock struct {
|
||||
GroupNorm *GroupNormLayer `weight:"group_norm"`
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
}
|
||||
|
||||
// Forward applies attention (NHWC format)
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
h := ab.GroupNorm.Forward(x)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
|
||||
q := ab.ToQ.Forward(h)
|
||||
k := ab.ToK.Forward(h)
|
||||
v := ab.ToV.Forward(h)
|
||||
|
||||
q = mlx.ExpandDims(q, 1)
|
||||
k = mlx.ExpandDims(k, 1)
|
||||
v = mlx.ExpandDims(v, 1)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
out = mlx.Squeeze(out, 1)
|
||||
|
||||
out = ab.ToOut.Forward(out)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Add(out, residual)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UpDecoderBlock2D implements an upsampling decoder block
|
||||
type UpDecoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Upsample *Conv2D
|
||||
}
|
||||
|
||||
// Forward applies the up decoder block
|
||||
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range ub.ResnetBlocks {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
|
||||
if ub.Upsample != nil {
|
||||
x = upsample2x(x)
|
||||
x = ub.Upsample.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2x performs 2x nearest neighbor upsampling
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
|
||||
hIdx = mlx.Reshape(hIdx, H, 1)
|
||||
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
|
||||
hIdx = mlx.Reshape(hIdx, H*2)
|
||||
|
||||
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
|
||||
wIdx = mlx.Reshape(wIdx, W, 1)
|
||||
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
|
||||
wIdx = mlx.Reshape(wIdx, W*2)
|
||||
|
||||
x = mlx.Take(x, hIdx, 1)
|
||||
x = mlx.Take(x, wIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block with attention
|
||||
type VAEMidBlock struct {
|
||||
Resnet1 *ResnetBlock2D
|
||||
Attention *VAEAttentionBlock
|
||||
Resnet2 *ResnetBlock2D
|
||||
}
|
||||
|
||||
// Forward applies the mid block
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
x = mb.Resnet1.Forward(x)
|
||||
x = mb.Attention.Forward(x)
|
||||
x = mb.Resnet2.Forward(x)
|
||||
return x
|
||||
}
|
||||
|
||||
// DefaultTilingConfig returns reasonable defaults for tiled decoding
|
||||
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
|
||||
func DefaultTilingConfig() *vae.TilingConfig {
|
||||
return vae.DefaultTilingConfig()
|
||||
}
|
||||
|
||||
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
|
||||
type AutoencoderKLFlux2 struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Encoder components (for image editing)
|
||||
EncoderConvIn *Conv2D
|
||||
EncoderMid *VAEMidBlock
|
||||
EncoderDown []*DownEncoderBlock2D
|
||||
EncoderNormOut *GroupNormLayer
|
||||
EncoderConvOut *Conv2D
|
||||
|
||||
// Decoder components
|
||||
DecoderConvIn *Conv2D
|
||||
DecoderMid *VAEMidBlock
|
||||
DecoderUp []*UpDecoderBlock2D
|
||||
DecoderNormOut *GroupNormLayer
|
||||
DecoderConvOut *Conv2D
|
||||
|
||||
// Quant conv layers
|
||||
QuantConv *Conv2D
|
||||
PostQuantConv *Conv2D
|
||||
|
||||
// BatchNorm for latent normalization
|
||||
LatentBN *BatchNorm2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// DownEncoderBlock2D implements a downsampling encoder block
|
||||
type DownEncoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Downsample *Conv2D
|
||||
}
|
||||
|
||||
// Forward applies the down encoder block
|
||||
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range db.ResnetBlocks {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
|
||||
if db.Downsample != nil {
|
||||
// Pad then conv with stride 2
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
|
||||
x = db.Downsample.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Load loads the Flux2 VAE from ollama blob storage.
|
||||
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights, &cfg)
|
||||
}
|
||||
|
||||
// loadWeights loads VAE weights from any WeightSource
|
||||
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
|
||||
var err error
|
||||
|
||||
// Load encoder components (for image conditioning)
|
||||
if err := m.loadEncoderWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
|
||||
// Load decoder conv_in
|
||||
m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_in: %w", err)
|
||||
}
|
||||
|
||||
// Load mid block
|
||||
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoder.mid_block: %w", err)
|
||||
}
|
||||
|
||||
// Load up blocks
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
hasUpsample := i < numBlocks-1
|
||||
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load decoder conv_norm_out and conv_out
|
||||
m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
|
||||
if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_norm_out: %w", err)
|
||||
}
|
||||
|
||||
m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_out: %w", err)
|
||||
}
|
||||
|
||||
// Load post_quant_conv
|
||||
if cfg.UsePostQuantConv {
|
||||
m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
|
||||
if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
|
||||
return fmt.Errorf("post_quant_conv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load latent BatchNorm (affine=False, so no weight/bias)
|
||||
bnMean, err := weights.GetTensor("bn.running_mean")
|
||||
if err != nil {
|
||||
return fmt.Errorf("bn.running_mean: %w", err)
|
||||
}
|
||||
bnVar, err := weights.GetTensor("bn.running_var")
|
||||
if err != nil {
|
||||
return fmt.Errorf("bn.running_var: %w", err)
|
||||
}
|
||||
m.LatentBN = &BatchNorm2D{
|
||||
RunningMean: bnMean,
|
||||
RunningVar: bnVar,
|
||||
Weight: nil, // affine=False
|
||||
Bias: nil, // affine=False
|
||||
Eps: cfg.BatchNormEps,
|
||||
Momentum: cfg.BatchNormMomentum,
|
||||
}
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadVAEMidBlock loads the mid block.
|
||||
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEMidBlock{
|
||||
Resnet1: resnet1,
|
||||
Attention: attention,
|
||||
Resnet2: resnet2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadResnetBlock2D loads a ResNet block.
|
||||
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
block := &ResnetBlock2D{
|
||||
Norm1: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
Conv1: &Conv2D{Stride: 1, Padding: 1},
|
||||
Norm2: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
Conv2: &Conv2D{Stride: 1, Padding: 1},
|
||||
ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
|
||||
}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If ConvShortcut wasn't loaded (no weights found), nil it out
|
||||
if block.ConvShortcut.Weight == nil {
|
||||
block.ConvShortcut = nil
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// loadVAEAttentionBlock loads an attention block using LoadModule.
|
||||
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
ab := &VAEAttentionBlock{
|
||||
GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
}
|
||||
if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ab, nil
|
||||
}
|
||||
|
||||
// loadUpDecoderBlock2D loads an up decoder block.
|
||||
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var upsample *Conv2D
|
||||
if hasUpsample {
|
||||
upsample = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &UpDecoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Upsample: upsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
|
||||
// This is the inverse of the VAE's patchify for feeding to transformer
|
||||
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
patchH := vae.Config.PatchSize[0]
|
||||
patchW := vae.Config.PatchSize[1]
|
||||
|
||||
pH := H / patchH
|
||||
pW := W / patchW
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
|
||||
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
|
||||
}
|
||||
|
||||
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
|
||||
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
|
||||
patchH := vae.Config.PatchSize[0]
|
||||
patchW := vae.Config.PatchSize[1]
|
||||
|
||||
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
|
||||
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
|
||||
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
|
||||
H := pH * patchH
|
||||
W := pW * patchW
|
||||
return mlx.Reshape(x, B, C, H, W)
|
||||
}
|
||||
|
||||
// denormalizePatchified applies inverse batch normalization to patchified latents.
|
||||
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
|
||||
// Output: [B, L, 128] denormalized
|
||||
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[2] // 128
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, C]
|
||||
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
|
||||
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
|
||||
|
||||
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
|
||||
if vae.LatentBN.Bias != nil {
|
||||
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
|
||||
x = mlx.Sub(x, bias)
|
||||
}
|
||||
if vae.LatentBN.Weight != nil {
|
||||
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
|
||||
x = mlx.Div(x, weight)
|
||||
}
|
||||
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
|
||||
x = mlx.Add(x, mean)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Decode decodes latent patches to images.
|
||||
// If Tiling is set, uses tiled decoding to reduce memory for large images.
|
||||
// latents: [B, L, C*4] patchified latents from transformer
|
||||
// pH, pW: patch grid dimensions
|
||||
// Returns: [B, 3, H, W] image tensor
|
||||
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
|
||||
// Denormalize patchified latents
|
||||
z := v.denormalizePatchified(latents)
|
||||
|
||||
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
|
||||
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
|
||||
|
||||
// Convert NCHW -> NHWC for processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
|
||||
// Use tiled decoding if enabled
|
||||
if v.Tiling != nil {
|
||||
mlx.Eval(z)
|
||||
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
|
||||
}
|
||||
|
||||
// Direct decode (no tiling)
|
||||
h := v.decodeTile(z)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
return h
|
||||
}
|
||||
|
||||
// decodeTile decodes a single latent tile to pixels (internal helper)
|
||||
// z: [B, H, W, C] latent tile in NHWC format
|
||||
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
|
||||
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
|
||||
// Post-quant conv
|
||||
if vae.PostQuantConv != nil {
|
||||
z = vae.PostQuantConv.Forward(z)
|
||||
}
|
||||
|
||||
// Decoder
|
||||
h := vae.DecoderConvIn.Forward(z)
|
||||
h = vae.DecoderMid.Forward(h)
|
||||
|
||||
for _, upBlock := range vae.DecoderUp {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
h = vae.DecoderNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.DecoderConvOut.Forward(h)
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// loadEncoderWeights loads the encoder components for image conditioning
|
||||
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
|
||||
var err error
|
||||
|
||||
// Load encoder conv_in
|
||||
m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_in: %w", err)
|
||||
}
|
||||
|
||||
// Load encoder down blocks
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
|
||||
hasDownsample := i < numBlocks-1
|
||||
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load encoder mid block
|
||||
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder.mid_block: %w", err)
|
||||
}
|
||||
|
||||
// Load encoder conv_norm_out and conv_out
|
||||
m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
|
||||
if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_norm_out: %w", err)
|
||||
}
|
||||
|
||||
m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_out: %w", err)
|
||||
}
|
||||
|
||||
// Load quant_conv (for encoding)
|
||||
if cfg.UseQuantConv {
|
||||
m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
|
||||
if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
|
||||
return fmt.Errorf("quant_conv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadDownEncoderBlock2D loads a down encoder block.
|
||||
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var downsample *Conv2D
|
||||
if hasDownsample {
|
||||
downsample = &Conv2D{Stride: 2, Padding: 0}
|
||||
if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &DownEncoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Downsample: downsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncodeImage encodes an image to normalized latents.
|
||||
// image: [B, 3, H, W] image tensor in [-1, 1]
|
||||
// Returns: [B, L, C*4] patchified normalized latents
|
||||
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
|
||||
// Convert NCHW -> NHWC
|
||||
x := mlx.Transpose(image, 0, 2, 3, 1)
|
||||
|
||||
// Encoder
|
||||
h := vae.EncoderConvIn.Forward(x)
|
||||
|
||||
for _, downBlock := range vae.EncoderDown {
|
||||
h = downBlock.Forward(h)
|
||||
}
|
||||
|
||||
h = vae.EncoderMid.Forward(h)
|
||||
h = vae.EncoderNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.EncoderConvOut.Forward(h)
|
||||
|
||||
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
|
||||
if vae.QuantConv != nil {
|
||||
h = vae.QuantConv.Forward(h)
|
||||
}
|
||||
|
||||
// Take only the mean (first latent_channels) - deterministic encoding
|
||||
// h is [B, H, W, 64] -> take first 32 channels for mean
|
||||
shape := h.Shape()
|
||||
latentChannels := vae.Config.LatentChannels // 32
|
||||
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
|
||||
|
||||
// Convert NHWC -> NCHW for patchifying
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
|
||||
// Patchify: [B, C, H, W] -> [B, L, C*4]
|
||||
h = vae.Patchify(h)
|
||||
|
||||
// Apply BatchNorm on patchified latents [B, L, 128]
|
||||
// The BatchNorm has 128 channels matching the patchified dimension
|
||||
h = vae.normalizePatchified(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// normalizePatchified applies batch normalization to patchified latents.
|
||||
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
|
||||
// Output: [B, L, 128] normalized
|
||||
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[2] // 128
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, C]
|
||||
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
|
||||
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
|
||||
|
||||
// Normalize: (x - mean) / sqrt(var + eps)
|
||||
xNorm := mlx.Sub(x, mean)
|
||||
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
|
||||
|
||||
// Scale and shift (only if affine=True)
|
||||
if vae.LatentBN.Weight != nil {
|
||||
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if vae.LatentBN.Bias != nil {
|
||||
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
@@ -1,390 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Config is the Qwen3 text encoder configuration. Each field is decoded from
// the JSON key named in its tag.
type Config struct {
	HiddenSize        int32   `json:"hidden_size"`
	NumHiddenLayers   int32   `json:"num_hidden_layers"`
	IntermediateSize  int32   `json:"intermediate_size"`
	NumAttentionHeads int32   `json:"num_attention_heads"`
	NumKeyValueHeads  int32   `json:"num_key_value_heads"`
	VocabSize         int32   `json:"vocab_size"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`
	HeadDim           int32   `json:"head_dim"`
}
|
||||
|
||||
// Attention implements Qwen3 attention with QK norms
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking and optional padding mask
|
||||
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// MLP implements Qwen3 SwiGLU MLP
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Block represents a single Qwen3 transformer block
|
||||
type Block struct {
|
||||
Attention *Attention `weight:"self_attn"`
|
||||
MLP *MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h, mask, maskMode)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// TextEncoder is the full Qwen3 encoder
|
||||
type TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Config
|
||||
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
m.Layers = make([]*Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *TextEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
|
||||
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps, attnMask, maskMode)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
|
||||
// This is used by Flux2 which needs embeddings from specific intermediate layers.
|
||||
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
outputs := make([]*mlx.Array, len(layerIndices))
|
||||
layerSet := make(map[int]int)
|
||||
for i, idx := range layerIndices {
|
||||
layerSet[idx] = i
|
||||
}
|
||||
|
||||
for i, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps, attnMask, maskMode)
|
||||
if outIdx, ok := layerSet[i]; ok {
|
||||
outputs[outIdx] = h
|
||||
}
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt as a single user turn in the Qwen3 chat
// format and opens the assistant turn.
//
// When think is true, an empty <think></think> block is appended after the
// assistant tag. Per the original note, this matches what Python's
// tokenizer.apply_chat_template emits with enable_thinking=False.
func ApplyChatTemplate(prompt string, think bool) string {
	const (
		userOpen      = "<|im_start|>user\n"
		turnEnd       = "<|im_end|>\n"
		assistantOpen = "<|im_start|>assistant\n"
		emptyThink    = "<think>\n\n</think>\n\n"
	)

	formatted := userOpen + prompt + turnEnd + assistantOpen
	if !think {
		return formatted
	}
	return formatted + emptyThink
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
|
||||
// If think is true, includes the <think></think> block in the chat template.
|
||||
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt, think)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
// Build combined causal + PAD mask [L, L]
|
||||
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
|
||||
L := int32(maxLen)
|
||||
validLen := int32(len(tokens))
|
||||
combinedMaskData := make([]float32, L*L)
|
||||
negInf := float32(-1e9)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
idx := i*L + j
|
||||
if j <= i && j < validLen {
|
||||
combinedMaskData[idx] = 0
|
||||
} else {
|
||||
combinedMaskData[idx] = negInf
|
||||
}
|
||||
}
|
||||
}
|
||||
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
|
||||
|
||||
embeddings := te.Forward(tokensArr, maskMat, "")
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
|
||||
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
|
||||
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
|
||||
// If think is true, includes the <think></think> block in the chat template.
|
||||
// Returns embeddings and padded sequence length.
|
||||
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt, think)
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
// Pad to maxLen
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
padded := make([]int32, maxLen)
|
||||
copy(padded, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
padded[i] = padToken
|
||||
}
|
||||
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
|
||||
|
||||
// Build combined causal + PAD mask [L, L]
|
||||
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
|
||||
// This combines causal masking with PAD token masking
|
||||
L := int32(maxLen)
|
||||
validLen := int32(len(tokens))
|
||||
maskData := make([]float32, L*L)
|
||||
negInf := float32(-1e9)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
idx := i*L + j
|
||||
if j <= i && j < validLen {
|
||||
maskData[idx] = 0 // allowed: causal OK and not PAD
|
||||
} else {
|
||||
maskData[idx] = negInf // blocked: future or PAD
|
||||
}
|
||||
}
|
||||
}
|
||||
maskMat := mlx.NewArray(maskData, []int32{L, L})
|
||||
|
||||
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
|
||||
|
||||
// Concatenate layer outputs along the hidden dimension
|
||||
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
|
||||
embeddings := mlx.Concatenate(layerOutputs, 2)
|
||||
|
||||
// Return embeddings and padded length
|
||||
return embeddings, int32(maxLen)
|
||||
}
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
@@ -31,6 +31,9 @@ type GenerateConfig struct {
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
@@ -114,7 +117,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
@@ -126,7 +129,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
|
||||
@@ -18,15 +18,18 @@ import (
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
|
||||
@@ -3,17 +3,287 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Re-export types from shared qwen3 package for backwards compatibility
|
||||
type (
|
||||
Qwen3Config = qwen3.Config
|
||||
Qwen3Attention = qwen3.Attention
|
||||
Qwen3MLP = qwen3.MLP
|
||||
Qwen3Block = qwen3.Block
|
||||
Qwen3TextEncoder = qwen3.TextEncoder
|
||||
)
|
||||
// Qwen3Config is the Qwen3 text encoder configuration. Each field is decoded
// from the JSON key named in its tag.
type Qwen3Config struct {
	HiddenSize        int32   `json:"hidden_size"`
	NumHiddenLayers   int32   `json:"num_hidden_layers"`
	IntermediateSize  int32   `json:"intermediate_size"`
	NumAttentionHeads int32   `json:"num_attention_heads"`
	NumKeyValueHeads  int32   `json:"num_key_value_heads"`
	VocabSize         int32   `json:"vocab_size"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`
	HeadDim           int32   `json:"head_dim"`
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking
|
||||
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Qwen3Block represents a single Qwen3 transformer block
|
||||
type Qwen3Block struct {
|
||||
Attention *Qwen3Attention `weight:"self_attn"`
|
||||
MLP *Qwen3MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
|
||||
type Qwen3TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Qwen3Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Qwen3Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = &cfg
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Qwen3TextEncoder) initComputedFields() {
|
||||
cfg := m.Qwen3Config
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in the Qwen3 chat format: the prompt becomes
// a single user turn, followed by the opening of the assistant turn.
//
// The prompt is inserted verbatim; no escaping of chat control tokens is
// performed.
func ApplyChatTemplate(prompt string) string {
	return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder
|
||||
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
embeddings := te.Forward(tokensArr)
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
@@ -637,9 +636,6 @@ type VAEDecoder struct {
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
@@ -734,60 +730,45 @@ func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfi
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Input latents are in NCHW format, output is in NCHW format.
|
||||
// If Tiling is set, uses tiled decoding to reduce memory for large images.
|
||||
func (v *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Internally uses NHWC format (MLX native) for all operations.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Scale latents
|
||||
z := mlx.DivScalar(latents, v.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, v.Config.ShiftFactor)
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
// Convert NCHW -> NHWC for internal processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
|
||||
// Use tiled decoding if enabled
|
||||
if v.Tiling != nil {
|
||||
mlx.Eval(z)
|
||||
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
|
||||
}
|
||||
|
||||
// Direct decode
|
||||
h := v.decodeTile(z)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// decodeTile decodes a single latent tile to pixels.
|
||||
// Input: [B, H, W, C] latent tile in NHWC format (already scaled)
|
||||
// Output: [B, H*8, W*8, 3] pixel tile in NHWC format
|
||||
func (v *VAEDecoder) decodeTile(z *mlx.Array) *mlx.Array {
|
||||
h := v.ConvIn.Forward(z)
|
||||
h := vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
|
||||
prev := h
|
||||
h = v.MidBlock.Forward(h)
|
||||
h = vae.MidBlock.Forward(h)
|
||||
prev.Free()
|
||||
|
||||
for _, upBlock := range v.UpBlocks {
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
prev = h
|
||||
h = upBlock.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
prev = h
|
||||
h = v.ConvNormOut.Forward(h)
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||
prev.Free()
|
||||
|
||||
prev = h
|
||||
h = mlx.SiLU(h)
|
||||
h = v.ConvOut.Forward(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
mlx.Eval(h)
|
||||
prev.Free()
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
@@ -35,6 +34,9 @@ type GenerateConfig struct {
|
||||
FusedQKV bool // Enable fused QKV projection (default: false)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
@@ -91,7 +93,7 @@ func (m *Model) Load(modelName string) error {
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
@@ -137,7 +139,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
@@ -149,7 +151,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
@@ -177,16 +179,9 @@ func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*m
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements runner.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
@@ -227,9 +222,9 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
@@ -453,11 +448,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
teaCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode - enable tiling for larger images to reduce memory
|
||||
// VAE attention is O(n²) on latent pixels, tiling helps significantly
|
||||
if latentH > 64 || latentW > 64 {
|
||||
m.VAEDecoder.Tiling = vae.DefaultTilingConfig()
|
||||
}
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
@@ -41,15 +40,10 @@ type Response struct {
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model ImageModel
|
||||
model *zimage.Model
|
||||
modelName string
|
||||
}
|
||||
|
||||
@@ -86,25 +80,10 @@ func Execute(args []string) error {
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType := imagegen.DetectModelType(*modelName)
|
||||
slog.Info("detected model type", "type", modelType)
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "Flux2KleinPipeline":
|
||||
m := &flux2.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
model = m
|
||||
// Load model
|
||||
model := &zimage.Model{}
|
||||
if err := model.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
@@ -180,19 +159,26 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image using the common interface
|
||||
// Generate image
|
||||
ctx := r.Context()
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
// Progress callback streams step updates
|
||||
progress := func(step, total int) {
|
||||
resp := Response{Step: step, Total: total}
|
||||
enc.Encode(resp)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: req.Seed,
|
||||
Progress: func(step, total int) {
|
||||
resp := Response{
|
||||
Step: step,
|
||||
Total: total,
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
|
||||
@@ -17,24 +17,6 @@ type WeightSource interface {
|
||||
GetTensor(name string) (*mlx.Array, error)
|
||||
ListTensors() []string
|
||||
HasTensor(name string) bool
|
||||
Quantization() string // Returns "FP4", "FP8", or ""
|
||||
}
|
||||
|
||||
// quantizationParams maps a quantization type name to the (groupSize, bits,
// mode) triple used for quantized linear layers. The name is compared
// case-insensitively: "FP4" yields (32, 4, "affine"); anything else,
// including "FP8" and the empty string, yields the backward-compatible
// default (32, 8, "affine").
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
	groupSize, bits, mode = 32, 8, "affine" // FP8 or unknown
	if strings.ToUpper(quantization) == "FP4" {
		bits = 4
	}
	return groupSize, bits, mode
}
|
||||
|
||||
// Transformer allows structs to transform weight arrays before assignment.
|
||||
// Implement this to apply operations like transpose during loading.
|
||||
type Transformer interface {
|
||||
Transform(field string, arr *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// LoadModule loads weights into a struct using reflection and struct tags.
|
||||
@@ -154,10 +136,6 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Transform before assigning if parent implements Transformer
|
||||
if t, ok := v.Addr().Interface().(Transformer); ok {
|
||||
arr = t.Transform(field.Name, arr)
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(arr))
|
||||
continue
|
||||
}
|
||||
@@ -245,21 +223,19 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
groupSize, bits, mode := quantizationParams(weights.Quantization())
|
||||
|
||||
if mlx.MetalIsAvailable() {
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
GroupSize: 32,
|
||||
Bits: 8,
|
||||
Mode: "affine",
|
||||
}, nil
|
||||
}
|
||||
|
||||
dequantized := mlx.Dequantize(weight, scales, qbiases, groupSize, bits, mode)
|
||||
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
|
||||
return nn.NewLinear(dequantized, bias), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -298,11 +298,6 @@ func (mw *ModelWeights) HasTensor(name string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// Quantization returns empty string for directory-based weights (not quantized).
|
||||
func (mw *ModelWeights) Quantization() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ReleaseAll releases all cached native file handles.
|
||||
func (mw *ModelWeights) ReleaseAll() {
|
||||
for path, native := range mw.nativeCache {
|
||||
|
||||
@@ -510,11 +510,7 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
|
||||
t.vocab.Merges[merge] = i
|
||||
}
|
||||
|
||||
// Add all added_tokens to vocabulary and special tokens map.
|
||||
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
|
||||
// they bypass BPE and get their own token ID. The "special" flag just indicates
|
||||
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
|
||||
// to treat all added_tokens as special to match HuggingFace behavior.
|
||||
// Add special tokens to vocabulary
|
||||
for _, tok := range raw.AddedTokens {
|
||||
if int(tok.ID) >= len(t.vocab.Values) {
|
||||
newValues := make([]string, tok.ID+1)
|
||||
@@ -522,7 +518,9 @@ func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
|
||||
t.vocab.Values = newValues
|
||||
}
|
||||
t.vocab.Values[tok.ID] = tok.Content
|
||||
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
|
||||
if tok.Special {
|
||||
t.specialTokens[tok.Content] = tok.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Load special token configuration from companion files
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package vae provides shared utilities for VAE (Variational Autoencoder) operations.
|
||||
package vae
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TilingConfig describes how a latent is cut into tiles for tiled VAE
// decoding, a general technique for lowering peak memory on large latents.
// Both values are measured in latent pixels.
type TilingConfig struct {
	TileSize int32 // edge length of one tile (e.g. 64 latent -> 512 pixels for an 8x VAE)
	Overlap  int32 // shared border between neighbouring tiles (e.g. 16 = 25% of 64)
}

// DefaultTilingConfig returns a fresh config with the defaults matching
// diffusers: tile_latent_min_size=64 and tile_overlap_factor=0.25, i.e. a
// 64-latent-pixel tile with a 16-latent-pixel overlap.
func DefaultTilingConfig() *TilingConfig {
	cfg := new(TilingConfig)
	cfg.TileSize = 64 // 64 latent pixels
	cfg.Overlap = 16  // 25% overlap
	return cfg
}
|
||||
|
||||
// decodedTile is one decoded tile copied out of MLX into host memory.
// DecodeTiled and the blend helpers index data as row-major pixels with
// three interleaved channels: (y*width + x)*3 + c.
type decodedTile struct {
	data          []float32 // pixel values
	height, width int32     // tile size in pixels
}
|
||||
|
||||
// DecodeTiled decodes latents using tiled processing with overlap blending.
|
||||
// This reduces memory usage for large images by processing in overlapping tiles.
|
||||
//
|
||||
// Parameters:
|
||||
// - latents: [1, H, W, C] latent tensor in NHWC format
|
||||
// - cfg: tiling configuration (tile size and overlap)
|
||||
// - decoder: function to decode a single tile [1, H, W, C] -> [1, H*scale, W*scale, 3]
|
||||
//
|
||||
// Returns: [1, 3, H*scale, W*scale] decoded image in NCHW format
|
||||
func DecodeTiled(latents *mlx.Array, cfg *TilingConfig, decoder func(*mlx.Array) *mlx.Array) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
H := shape[1] // latent height
|
||||
W := shape[2] // latent width
|
||||
C := shape[3]
|
||||
|
||||
tileLatentSize := cfg.TileSize
|
||||
overlapLatent := cfg.Overlap
|
||||
|
||||
// If image is small enough, just decode normally
|
||||
if H <= tileLatentSize && W <= tileLatentSize {
|
||||
decoded := decoder(latents)
|
||||
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
decoded = mlx.Transpose(decoded, 0, 3, 1, 2) // NHWC -> NCHW
|
||||
return decoded
|
||||
}
|
||||
|
||||
// Calculate tiling parameters (matching diffusers)
|
||||
overlapSize := tileLatentSize - overlapLatent // stride in latent space
|
||||
|
||||
// Blend extent in pixel space (assumes 8x upscale, adjust if needed)
|
||||
// For other scale factors, this could be made configurable
|
||||
tileSampleSize := tileLatentSize * 8 // tile size in pixels after 8x upscale
|
||||
blendExtent := overlapLatent * 8 // blend region in pixels
|
||||
rowLimit := tileSampleSize - blendExtent // non-overlapping region per tile
|
||||
|
||||
// Phase 1: Decode all tiles and store in 2D grid
|
||||
var rows [][]decodedTile
|
||||
|
||||
for i := int32(0); i < H; i += overlapSize {
|
||||
var row []decodedTile
|
||||
for j := int32(0); j < W; j += overlapSize {
|
||||
// Extract tile (may be smaller at edges)
|
||||
i2 := min(i+tileLatentSize, H)
|
||||
j2 := min(j+tileLatentSize, W)
|
||||
|
||||
tile := mlx.Slice(latents, []int32{0, i, j, 0}, []int32{1, i2, j2, C})
|
||||
decoded := decoder(tile)
|
||||
decoded = mlx.AsType(decoded, mlx.DtypeFloat32)
|
||||
mlx.Eval(decoded)
|
||||
|
||||
decodedShape := decoded.Shape()
|
||||
tileH := decodedShape[1]
|
||||
tileW := decodedShape[2]
|
||||
tileData := decoded.Data()
|
||||
decoded.Free()
|
||||
|
||||
row = append(row, decodedTile{data: tileData, height: tileH, width: tileW})
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
// Phase 2: Blend adjacent tiles (modifies in place)
|
||||
for i := range rows {
|
||||
for j := range rows[i] {
|
||||
tile := &rows[i][j]
|
||||
|
||||
// Blend with tile above
|
||||
if i > 0 {
|
||||
above := &rows[i-1][j]
|
||||
blendV(above, tile, blendExtent)
|
||||
}
|
||||
|
||||
// Blend with tile to the left
|
||||
if j > 0 {
|
||||
left := &rows[i][j-1]
|
||||
blendH(left, tile, blendExtent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Calculate crop dimensions for each tile
|
||||
colWidths := make([]int32, len(rows[0]))
|
||||
for j := range rows[0] {
|
||||
keepW := rowLimit
|
||||
if int32(j+1)*overlapSize >= W {
|
||||
keepW = rows[0][j].width
|
||||
}
|
||||
colWidths[j] = keepW
|
||||
}
|
||||
|
||||
rowHeights := make([]int32, len(rows))
|
||||
for i := range rows {
|
||||
keepH := rowLimit
|
||||
if int32(i+1)*overlapSize >= H {
|
||||
keepH = rows[i][0].height
|
||||
}
|
||||
rowHeights[i] = keepH
|
||||
}
|
||||
|
||||
// Calculate total dimensions
|
||||
var totalW, totalH int32
|
||||
for _, w := range colWidths {
|
||||
totalW += w
|
||||
}
|
||||
for _, h := range rowHeights {
|
||||
totalH += h
|
||||
}
|
||||
|
||||
// Phase 4: Assemble final image by interleaving tiles row-by-row
|
||||
finalData := make([]float32, totalH*totalW*3)
|
||||
|
||||
dstY := int32(0)
|
||||
for i, row := range rows {
|
||||
keepH := rowHeights[i]
|
||||
|
||||
for y := int32(0); y < keepH; y++ {
|
||||
dstX := int32(0)
|
||||
for j, tile := range row {
|
||||
keepW := colWidths[j]
|
||||
|
||||
for x := int32(0); x < keepW; x++ {
|
||||
for c := int32(0); c < 3; c++ {
|
||||
srcIdx := (y*tile.width + x) * 3 + c
|
||||
dstIdx := ((dstY + y) * totalW + (dstX + x)) * 3 + c
|
||||
finalData[dstIdx] = tile.data[srcIdx]
|
||||
}
|
||||
}
|
||||
dstX += keepW
|
||||
}
|
||||
}
|
||||
dstY += keepH
|
||||
}
|
||||
|
||||
// Create mlx array [1, H, W, 3] then transpose to NCHW [1, 3, H, W]
|
||||
result := mlx.NewArray(finalData, []int32{1, totalH, totalW, 3})
|
||||
result = mlx.Transpose(result, 0, 3, 1, 2)
|
||||
result = mlx.ClipScalar(result, 0.0, 1.0, true, true)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// blendV blends the bottom of 'above' tile into top of 'current' tile (vertical blend)
|
||||
// Matches diffusers blend_v formula
|
||||
func blendV(above, current *decodedTile, blendExtent int32) {
|
||||
blend := min(blendExtent, min(above.height, current.height))
|
||||
if blend <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
w := min(above.width, current.width)
|
||||
for y := int32(0); y < blend; y++ {
|
||||
alpha := float32(y) / float32(blend)
|
||||
for x := int32(0); x < w; x++ {
|
||||
for c := int32(0); c < 3; c++ {
|
||||
aboveIdx := ((above.height - blend + y) * above.width + x) * 3 + c
|
||||
currIdx := (y * current.width + x) * 3 + c
|
||||
current.data[currIdx] = above.data[aboveIdx]*(1-alpha) + current.data[currIdx]*alpha
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// blendH blends the right of 'left' tile into left of 'current' tile (horizontal blend)
|
||||
// Matches diffusers blend_h formula
|
||||
func blendH(left, current *decodedTile, blendExtent int32) {
|
||||
blend := min(blendExtent, min(left.width, current.width))
|
||||
if blend <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
h := min(left.height, current.height)
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < blend; x++ {
|
||||
alpha := float32(x) / float32(blend)
|
||||
for c := int32(0); c < 3; c++ {
|
||||
leftIdx := (y * left.width + (left.width - blend + x)) * 3 + c
|
||||
currIdx := (y * current.width + x) * 3 + c
|
||||
current.data[currIdx] = left.data[leftIdx]*(1-alpha) + current.data[currIdx]*alpha
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -106,21 +106,6 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// Quantization returns the model's quantization type from model_index.json.
|
||||
// Returns empty string if not quantized or unknown.
|
||||
func (mw *ManifestWeights) Quantization() string {
|
||||
if mw.manifest == nil {
|
||||
return ""
|
||||
}
|
||||
var index struct {
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
|
||||
return ""
|
||||
}
|
||||
return index.Quantization
|
||||
}
|
||||
|
||||
// ReleaseAll frees all native handles and clears the tensor cache.
|
||||
func (mw *ManifestWeights) ReleaseAll() {
|
||||
for _, sf := range mw.nativeCache {
|
||||
|
||||
Reference in New Issue
Block a user