Compare commits


1 Commit

Author SHA1 Message Date
Michael Yang 5f9edb46fc move tokenizers to separate package 2026-01-21 10:42:46 -08:00
112 changed files with 8700 additions and 6974 deletions

View File

@@ -749,7 +749,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
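
The only change in this hunk is the `omitempty` tag on `ModelInfo`. A minimal, hypothetical sketch (a simplified stand-in struct, not the real `api.ShowResponse`) of what that tag changes in the marshaled JSON:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// showResponse keeps only the field touched by this hunk.
type showResponse struct {
	ModelInfo map[string]any `json:"model_info,omitempty"`
}

func main() {
	// With omitempty, a nil map is dropped from the output instead of
	// being serialized as "model_info":null.
	b, _ := json.Marshal(showResponse{})
	fmt.Println(string(b)) // {}

	// A populated map is serialized as before.
	b, _ = json.Marshal(showResponse{ModelInfo: map[string]any{"general.architecture": "llama"}})
	fmt.Println(string(b)) // {"model_info":{"general.architecture":"llama"}}
}
```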

View File

@@ -899,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, arg := range args {
// Unload the model if it's running before deletion
if err := loadOrUnloadModel(cmd, &runOptions{
Model: arg,
Model: args[0],
KeepAlive: &api.Duration{Duration: 0},
}); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg)
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
}
}
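
For context, the `KeepAlive: &api.Duration{Duration: 0}` argument is what makes `loadOrUnloadModel` unload the model rather than load it. A hedged sketch (hypothetical model name, not part of this diff) of the same pattern through the public Go client:

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// A request with no prompt and a zero keep-alive asks the server to
	// unload the model if it is currently resident, mirroring what the
	// delete handler does before removing it.
	req := &api.GenerateRequest{
		Model:     "llama3.2", // hypothetical model name
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		log.Fatal(err)
	}
}
```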

View File

@@ -311,10 +311,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
case "Lfm2ForCausalLM":
conv = &lfm2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,150 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type glm4MoeLiteModel struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
}
func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "glm4moelite"
kv["general.type"] = "model"
kv["glm4moelite.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["glm4moelite.attention.head_count"] = numHeads
kv["glm4moelite.attention.head_count_kv"] = numKVHeads
kv["glm4moelite.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["glm4moelite.attention.kv_lora_rank"] = p.KVLoraRank
kv["glm4moelite.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["glm4moelite.attention.q_lora_rank"] = p.QLoraRank
kv["glm4moelite.attention.value_length"] = p.VHeadDim
kv["glm4moelite.context_length"] = p.MaxPositionEmbeddings
kv["glm4moelite.embedding_length"] = p.HiddenSize
kv["glm4moelite.expert_count"] = p.ExpertCount
kv["glm4moelite.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["glm4moelite.expert_shared_count"] = p.ExpertSharedCount
kv["glm4moelite.expert_gating_func"] = uint32(2)
kv["glm4moelite.expert_used_count"] = p.ExpertUsedCount
kv["glm4moelite.expert_weights_norm"] = p.ExpertWeightsNorm
kv["glm4moelite.expert_weights_scale"] = p.ExpertWeightsScale
kv["glm4moelite.feed_forward_length"] = p.IntermediateSize
kv["glm4moelite.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["tokenizer.ggml.pre"] = "glm4"
return kv
}
func (p *glm4MoeLiteModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
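
The converter above merges per-expert tensors and then drops any block whose index is at or beyond `num_hidden_layers`. A small standalone illustration of that skip check (tensor names are made up):

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// skipLayer mirrors the check above: any tensor whose block index is at or
// beyond the configured hidden-layer count is dropped during conversion.
func skipLayer(name string, blockCount uint32) bool {
	re := regexp.MustCompile(`^blk\.(\d+)`)
	m := re.FindStringSubmatch(name)
	if m == nil {
		return false
	}
	n, err := strconv.Atoi(m[1])
	if err != nil {
		return false
	}
	return uint32(n) >= blockCount
}

func main() {
	fmt.Println(skipLayer("blk.27.ffn_up.weight", 28))             // false: regular layer
	fmt.Println(skipLayer("blk.28.nextn.embed_tokens.weight", 28)) // true: extra (MTP) layer, skipped
}
```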

View File

@@ -1,100 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type lfm2Model struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
ConvLCache uint32 `json:"conv_L_cache"`
LayerTypes []string `json:"layer_types"`
TieEmbedding bool `json:"tie_embedding"`
}
var _ ModelConverter = (*lfm2Model)(nil)
func (p *lfm2Model) KV(t *Tokenizer) KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "lfm2"
kv["lfm2.vocab_size"] = p.VocabSize
kv["lfm2.block_count"] = p.NumHiddenLayers
kv["lfm2.embedding_length"] = p.HiddenSize
kv["lfm2.feed_forward_length"] = p.IntermediateSize
kv["lfm2.context_length"] = p.MaxPositionEmbeddings
// Build per-layer KV head count array based on layer_types
// (0 = shortconv layer, non-zero = attention layer with that many KV heads)
kvHeadCounts := make([]uint32, p.NumHiddenLayers)
for i := range p.NumHiddenLayers {
if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
kvHeadCounts[i] = p.NumKeyValueHeads
}
}
kv["lfm2.attention.head_count"] = p.NumAttentionHeads
kv["lfm2.attention.head_count_kv"] = kvHeadCounts
kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["lfm2.rope.freq_base"] = p.RopeTheta
kv["lfm2.shortconv.l_cache"] = p.ConvLCache
return kv
}
func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
shape := t.Shape()
// Squeeze conv weights: [D, 1, K] -> [D, K]
if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
if len(shape) == 3 && shape[1] == 1 {
shape = []uint64{shape[0], shape[2]}
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: slices.Clone(shape),
WriterTo: t,
})
}
return out
}
func (p *lfm2Model) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.embedding_norm", "output_norm",
"model.layers", "blk",
"operator_norm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.out_proj", "attn_output",
"self_attn.q_layernorm", "attn_q_norm",
"self_attn.k_layernorm", "attn_k_norm",
"conv.conv", "shortconv.conv",
"conv.in_proj", "shortconv.in_proj",
"conv.out_proj", "shortconv.out_proj",
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
"ffn_norm", "ffn_norm",
}
}
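
The per-layer head-count array built in `KV` above is how the runtime tells shortconv layers apart from attention layers. A toy illustration of the same loop, with a made-up layer layout:

```go
package main

import "fmt"

func main() {
	// Hypothetical layout; the config marks attention layers as
	// "full_attention" and everything else is treated as a shortconv layer.
	layerTypes := []string{"conv", "conv", "full_attention", "conv", "full_attention"}
	const numKeyValueHeads uint32 = 8

	kvHeadCounts := make([]uint32, len(layerTypes))
	for i, t := range layerTypes {
		if t == "full_attention" {
			kvHeadCounts[i] = numKeyValueHeads
		}
	}
	fmt.Println(kvHeadCounts) // [0 0 8 0 8]
}
```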

View File

@@ -40,7 +40,6 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||

View File

@@ -2,7 +2,7 @@
title: Claude Code
---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models such as `qwen3-coder` and `gpt-oss:20b` can be used with Claude Code through Ollama's Anthropic-compatible API.
@@ -26,16 +26,6 @@ irm https://claude.ai/install.ps1 | iex
## Usage with Ollama
Configure Claude Code to use Ollama:
```shell
ollama config claude
```
This will prompt you to select a model and automatically configure Claude Code to use Ollama.
<Accordion title="Manual Configuration">
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
@@ -57,9 +47,7 @@ Or run with environment variables inline:
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
```
</Accordion>
<Note>Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.</Note>
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Connecting to ollama.com
@@ -87,4 +75,4 @@ claude --model glm-4.7:cloud
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -2,31 +2,22 @@
title: Codex
---
Codex is OpenAI's agentic coding tool for the command line.
## Install
Install the [Codex CLI](https://developers.openai.com/codex/cli/):
```shell
```
npm install -g @openai/codex
```
## Usage with Ollama
Configure Codex to use Ollama:
```shell
ollama config codex
```
This will prompt you to select a model and automatically configure Codex to use Ollama.
<Accordion title="Manual Configuration">
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
To use `codex` with Ollama, use the `--oss` flag:
```shell
```
codex --oss
```
@@ -34,22 +25,20 @@ codex --oss
By default, codex will use the local `gpt-oss:20b` model. However, you can specify a different model with the `-m` flag:
```shell
```
codex --oss -m gpt-oss:120b
```
### Cloud Models
```shell
```
codex --oss -m gpt-oss:120b-cloud
```
</Accordion>
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
## Connecting to ollama.com
Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
To use ollama.com directly, edit your `~/.codex/config.toml` file to point to ollama.com.

View File

@@ -2,7 +2,6 @@
title: Droid
---
Droid is Factory's agentic coding tool for the command line.
## Install
@@ -12,80 +11,66 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
Configure Droid to use Ollama:
```shell
ollama config droid
```
This will prompt you to select models and automatically configure Droid to use Ollama.
<Accordion title="Manual Configuration">
Add a local configuration block to `~/.factory/settings.json`:
Add a local configuration block to `~/.factory/config.json`:
```json
{
"customModels": [
"custom_models": [
{
"model_display_name": "qwen3-coder [Ollama]",
"model": "qwen3-coder",
"displayName": "qwen3-coder [Ollama]",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"base_url": "http://localhost:11434/v1/",
"api_key": "not-needed",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 32000
"max_tokens": 32000
}
]
}
```
Adjust `maxOutputTokens` based on your model's context length (the automated setup detects this for you).
### Cloud Models
## Cloud Models
`qwen3-coder:480b-cloud` is the recommended model for use with Droid.
Add the cloud configuration block to `~/.factory/settings.json`:
Add the cloud configuration block to `~/.factory/config.json`:
```json
{
"customModels": [
"custom_models": [
{
"model_display_name": "qwen3-coder [Ollama Cloud]",
"model": "qwen3-coder:480b-cloud",
"displayName": "qwen3-coder:480b-cloud [Ollama]",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"base_url": "http://localhost:11434/v1/",
"api_key": "not-needed",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 128000
"max_tokens": 128000
}
]
}
```
</Accordion>
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Add the cloud configuration block to `~/.factory/settings.json`:
2. Add the cloud configuration block to `~/.factory/config.json`:
```json
{
"customModels": [
"custom_models": [
{
"model_display_name": "qwen3-coder [Ollama Cloud]",
"model": "qwen3-coder:480b",
"displayName": "qwen3-coder:480b [Ollama Cloud]",
"baseUrl": "https://ollama.com/v1",
"apiKey": "OLLAMA_API_KEY",
"base_url": "https://ollama.com/v1/",
"api_key": "OLLAMA_API_KEY",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 128000
"max_tokens": 128000
}
]
}
```
Run `droid` in a new terminal to load the new settings.

View File

@@ -1,63 +0,0 @@
---
title: OpenCode
---
OpenCode is an agentic coding tool for the terminal.
## Install
Install [OpenCode](https://opencode.ai):
```shell
curl -fsSL https://opencode.ai/install | bash
```
## Usage with Ollama
Configure OpenCode to use Ollama:
```shell
ollama config opencode
```
This will prompt you to select models and automatically configure OpenCode to use Ollama.
<Accordion title="Manual Configuration">
Add the Ollama provider to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama (local)",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder [Ollama]"
}
}
}
}
}
```
</Accordion>
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
## Recommended Models
### Cloud models
- `qwen3-coder:480b` - Large coding model
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -269,8 +269,6 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
"lfm2",
}, kv.Architecture())
}
@@ -858,9 +856,7 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"bert",
"gemma3",
"glm4moelite",
"gptoss", "gpt-oss",
"lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",

View File

@@ -1,148 +0,0 @@
//go:build integration
package integration
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestImageGeneration(t *testing.T) {
skipUnderMinVRAM(t, 8)
type testCase struct {
imageGenModel string
visionModel string
prompt string
expectedWords []string
}
testCases := []testCase{
{
imageGenModel: "jmorgan/z-image-turbo",
visionModel: "llama3.2-vision",
prompt: "A cartoon style llama flying like a superhero through the air with clouds in the background",
expectedWords: []string{"llama", "flying", "cartoon", "cloud", "sky", "superhero", "air", "animal", "camelid"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s->%s", tc.imageGenModel, tc.visionModel), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
if err := PullIfMissing(ctx, client, tc.imageGenModel); err != nil {
t.Fatalf("failed to pull image gen model: %v", err)
}
if err := PullIfMissing(ctx, client, tc.visionModel); err != nil {
t.Fatalf("failed to pull vision model: %v", err)
}
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
} else if strings.Contains(err.Error(), "executable file not found in") { // Windows pattern, not yet supported
t.Skip("Windows does not support image generation yet")
} else if strings.Contains(err.Error(), "CUDA driver version is insufficient") {
t.Skip("Driver is too old")
} else if strings.Contains(err.Error(), "insufficient memory for image generation") {
t.Skip("insufficient memory for image generation")
} else if strings.Contains(err.Error(), "error while loading shared libraries: libcuda.so.1") { // AMD GPU or CPU
t.Skip("CUDA GPU is not available")
} else if strings.Contains(err.Error(), "ollama-mlx: no such file or directory") {
// most likely linux arm - not supported yet
t.Skip("unsupported architecture")
}
t.Fatalf("failed to generate image: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
t.Fatalf("failed to decode image: %v", err)
}
t.Logf("Generated image: %d bytes", len(imageData))
// Preload vision model and check GPU loading
err = client.Generate(ctx, &api.GenerateRequest{Model: tc.visionModel}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load vision model: %v", err)
}
// Use vision model to describe the image
chatReq := api.ChatRequest{
Model: tc.visionModel,
Messages: []api.Message{
{
Role: "user",
Content: "Describe this image in detail. What is shown? What style is it? What is the main subject doing?",
Images: []api.ImageData{imageData},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}
// Verify the vision model's response contains expected keywords
response := DoChat(ctx, t, client, chatReq, tc.expectedWords, 240*time.Second, 30*time.Second)
if response != nil {
t.Logf("Vision model response: %s", response.Content)
// Additional detailed check for keywords
content := strings.ToLower(response.Content)
foundWords := []string{}
missingWords := []string{}
for _, word := range tc.expectedWords {
if strings.Contains(content, word) {
foundWords = append(foundWords, word)
} else {
missingWords = append(missingWords, word)
}
}
t.Logf("Found keywords: %v", foundWords)
if len(missingWords) > 0 {
t.Logf("Missing keywords (at least one was found so test passed): %v", missingWords)
}
}
})
}
}
// generateImage calls the Ollama API to generate an image and returns the base64 image data
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
var imageBase64 string
err := client.Generate(ctx, &api.GenerateRequest{
Model: model,
Prompt: prompt,
}, func(resp api.GenerateResponse) error {
if resp.Image != "" {
imageBase64 = resp.Image
}
return nil
})
if err != nil {
return "", fmt.Errorf("failed to generate image: %w", err)
}
if imageBase64 == "" {
return "", fmt.Errorf("no image data in response")
}
return imageBase64, nil
}

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -144,7 +143,6 @@ var (
"granite3.3",
"hermes3",
"internlm2",
"lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -265,7 +263,6 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
"lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",

View File

@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizers"
)
type filteredEnv []string
@@ -115,7 +116,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
tokenizer tokenizers.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -141,11 +142,11 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
var textProcessor model.TextProcessor
var tokenizer tokenizers.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
tokenizer, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -154,7 +155,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
if tokenizer == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -210,7 +211,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
if tokenizer == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -260,7 +261,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
textProcessor != nil,
tokenizer != nil,
modelPath,
gpuLibs,
status,
@@ -309,8 +310,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
if tokenizer != nil {
return &ollamaServer{llmServer: s, tokenizer: tokenizer}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1772,7 +1773,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1807,7 +1808,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}
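
The renamed `tokenizer` field keeps the same `Encode`/`Decode` surface the server already relied on. A hedged sketch (hypothetical model path, signatures as shown in the hunks) of the round trip that `Tokenize` and `Detokenize` perform:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/tokenizers"
)

// roundTrip mirrors what ollamaServer.Tokenize and Detokenize do through the
// renamed field: encode without special tokens, then decode the IDs back.
func roundTrip(tok tokenizers.Tokenizer, content string) (string, error) {
	ids, err := tok.Encode(content, false)
	if err != nil {
		return "", err
	}
	return tok.Decode(ids)
}

func main() {
	// model.NewTextProcessor now returns a tokenizers.Tokenizer in this change.
	tok, err := model.NewTextProcessor("/path/to/model.gguf") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	out, err := roundTrip(tok, "hello world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```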

View File

@@ -1,95 +0,0 @@
package manifest
import (
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
var ErrInvalidDigestFormat = errors.New("invalid digest format")
func Path() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PathForName returns the path to the manifest file for a specific model name.
func PathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
manifests, err := Path()
if err != nil {
return "", err
}
return filepath.Join(manifests, n.Filepath()), nil
}
func BlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
// PruneDirectory removes empty directories recursively.
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
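
For reference, the digest validation in `BlobsPath` above accepts both separator forms and normalizes to the dash form used on disk. A standalone check with dummy digests:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	// Same pattern as BlobsPath: "sha256" plus ":" or "-" plus 64 hex characters.
	re := regexp.MustCompile("^sha256[:-][0-9a-fA-F]{64}$")

	digests := []string{
		"sha256:" + strings.Repeat("ab", 32), // valid, colon form
		"sha256-" + strings.Repeat("ab", 32), // valid, dash form
		"sha256:deadbeef",                    // invalid: too short
	}
	for _, d := range digests {
		// BlobsPath stores blobs with "-" in the filename, so ":" is rewritten.
		fmt.Println(re.MatchString(d), strings.ReplaceAll(d, ":", "-"))
	}
}
```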

View File

@@ -162,7 +162,6 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

View File

@@ -1641,13 +1641,6 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,

model/ignore_test.go (new file, 409 lines)

View File

File diff suppressed because one or more lines are too long

View File

@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
var (
@@ -119,7 +120,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
func NewTextProcessor(s string) (tokenizers.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -136,7 +137,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
tp, ok := m.(TextProcessor)
tp, ok := m.(tokenizers.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
vocab := &model.Vocabulary{
vocab := &tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
var processor model.TextProcessor
var tokenizer tokenizers.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(vocab, true)
tokenizer = tokenizers.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Tokenizer: tokenizer,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.SentencePiece
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizers.NewSentencePiece(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type embedModel struct {
model.Base
model.SentencePiece
tokenizers.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizers.NewSentencePiece(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
var processor model.TextProcessor
var tokenizer tokenizers.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
tokenizer = tokenizers.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add it in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
tokenizer = tokenizers.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
Tokenizer: tokenizer,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),

View File

@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.SentencePiece
tokenizers.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Tokenizer: tokenizers.NewSentencePiece(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),

View File

@@ -1,304 +0,0 @@
package glm4moelite
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads int
eps,
ropeBase float32
kqScale float64
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1.0)
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
compressedKV.Stride(1), 1,
compressedKV.Stride(1), compressedKV.Dim(1),
)
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
if moe.ExpProbsBias != nil {
scores = scores.Add(ctx, moe.ExpProbsBias)
}
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "glm4":
pre = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
pre...,
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("glm4moelite", New)
}
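
The routing in `sparse.Forward` above is compact tensor code; a toy, scalar restatement of the same math (made-up numbers, plain slices instead of tensors) may help: sigmoid scores, a selection-only bias, top-k, optional renormalization, then the routed scaling factor:

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

func main() {
	logits := []float64{1.2, -0.3, 0.8, 2.0} // router output for one token, 4 experts
	bias := []float64{0.0, 0.1, 0.0, -0.5}   // exp_probs_b bias (affects selection only)
	const topK = 2
	const routedScalingFactor = 2.5
	const normTopKProb = true

	scores := make([]float64, len(logits))
	biased := make([]float64, len(logits))
	for i, l := range logits {
		scores[i] = 1 / (1 + math.Exp(-l)) // sigmoid, not softmax
		biased[i] = scores[i] + bias[i]    // bias picks the winners but not their weights
	}

	// select top-k experts by biased score
	idx := []int{0, 1, 2, 3}
	sort.Slice(idx, func(a, b int) bool { return biased[idx[a]] > biased[idx[b]] })
	chosen := idx[:topK]

	// weights come from the unbiased scores, optionally renormalized, then scaled
	weights := make([]float64, topK)
	var sum float64
	for i, e := range chosen {
		weights[i] = scores[e]
		sum += scores[e]
	}
	for i := range weights {
		if normTopKProb {
			weights[i] /= sum
		}
		weights[i] *= routedScalingFactor
	}
	fmt.Println("experts:", chosen, "weights:", weights)
}
```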

View File

@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Transformer struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,410 +0,0 @@
package lfm2
import (
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var _ kvcache.Cache = (*HybridCache)(nil)
// HybridCache stores:
// - a standard causal KV cache for attention layers
// - a per-sequence recurrent conv state for shortconv layers
//
// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
type HybridCache struct {
kv *kvcache.Causal
backend ml.Backend
dtype ml.DType
maxSequences int
hiddenSize int
dConv int
// slot mapping for recurrent state
slotForSeq map[int]int
refCount []int
freeSlots []int
// per-layer conv state buffers (allocated lazily)
convCtxs map[int]ml.Context
convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
// current forward batch (derived in StartForward)
curSeqs []int
curSlots []int
curSlotsInput ml.Tensor
curSeqTokens int
// track if EnsureWritable has been called for this forward pass
writableEnsured bool
// track any error from EnsureWritable to propagate later
writableError error
}
func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
return &HybridCache{
kv: kvcache.NewCausalCache(shift),
hiddenSize: hiddenSize,
dConv: dConv,
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.backend = backend
c.dtype = dtype
c.maxSequences = maxSequences
// initialize slot allocator
c.refCount = make([]int, maxSequences)
c.freeSlots = c.freeSlots[:0]
for i := maxSequences - 1; i >= 0; i-- {
c.freeSlots = append(c.freeSlots, i)
}
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
func (c *HybridCache) Close() {
for _, ctx := range c.convCtxs {
ctx.Close()
}
c.kv.Close()
}
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
c.kv.SetConfig(config)
}
func (c *HybridCache) SetLayer(layer int) {
c.kv.SetLayer(layer)
}
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.kv.Get(ctx)
}
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.kv.Put(ctx, key, value)
}
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
return err
}
// Derive equal-length sequence layout for shortconv.
// LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
seqCounts := make(map[int]int)
c.curSeqs = c.curSeqs[:0]
for _, s := range batch.Sequences {
if _, ok := seqCounts[s]; !ok {
c.curSeqs = append(c.curSeqs, s)
}
seqCounts[s]++
}
if len(c.curSeqs) == 0 {
return nil
}
nTokens := len(batch.Sequences)
nSeqs := len(c.curSeqs)
want := nTokens / nSeqs
for _, s := range c.curSeqs {
if seqCounts[s] != want {
return kvcache.ErrNotSupported
}
}
c.curSeqTokens = want
// When reserving memory for estimation, use fake slot assignments
// without modifying permanent state (slotForSeq, refCount)
if reserve {
c.curSlots = c.curSlots[:0]
slots := make([]int32, nSeqs)
for i := range nSeqs {
c.curSlots = append(c.curSlots, i)
slots[i] = int32(i)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
// Ensure slots exist for sequences in this batch
c.curSlots = c.curSlots[:0]
var newSlots []int // track newly allocated slots that need zeroing
for _, s := range c.curSeqs {
slot, ok := c.slotForSeq[s]
if !ok {
var err error
slot, err = c.allocSlot()
if err != nil {
return err
}
c.slotForSeq[s] = slot
c.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
c.curSlots = append(c.curSlots, slot)
}
// Zero conv state for newly allocated slots to clear stale data from previous sequences
if len(newSlots) > 0 {
c.zeroConvSlots(ctx, newSlots)
}
// Create a tensor for the current slots
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
// Reset writable state for new forward pass
c.writableEnsured = false
c.writableError = nil
return nil
}
func (c *HybridCache) allocSlot() (int, error) {
if len(c.freeSlots) == 0 {
return 0, kvcache.ErrKvCacheFull
}
slot := c.freeSlots[len(c.freeSlots)-1]
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
return slot, nil
}
func (c *HybridCache) freeSlot(slot int) {
// Bounds check before freeing
if slot >= 0 && slot < c.maxSequences {
c.freeSlots = append(c.freeSlots, slot)
}
}
// zeroConvSlots zeros the conv state for the given slots across all layers.
// This must be called when recycling slots to prevent stale state from affecting new sequences.
func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
if len(slots) == 0 || len(c.convStates) == 0 {
return
}
// Use input context for creating tensors
inputCtx := ctx.Input()
// Create slot indices tensor
slotIndices := make([]int32, len(slots))
for i, s := range slots {
slotIndices[i] = int32(s)
}
slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
// Create zero tensor for the slots (SetRows requires F32 source)
zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
// Zero each layer's conv state for these slots
for _, buf := range c.convStates {
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
}
}
// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
// Returns an error if slot allocation fails.
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
for i, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
continue
}
if c.refCount[slot] <= 1 {
continue
}
newSlot, err := c.allocSlot()
if err != nil {
return err
}
c.refCount[slot]--
c.refCount[newSlot] = 1
c.slotForSeq[seq] = newSlot
c.curSlots[i] = newSlot
// Copy existing conv state for all initialized layers
for _, buf := range c.convStates {
// buf: [dConv*hiddenSize, maxSlots]
src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
}
}
// Rebuild current slots tensor
slots := make([]int32, len(c.curSlots))
for i, v := range c.curSlots {
slots[i] = int32(v)
}
c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
return nil
}
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
// KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
// For shortconv state we implement copy-on-write: dst shares the same slot as src.
// On the first write to dst, EnsureWritable will create a private slot.
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
// Bounds check before decrementing
if dstSlot >= 0 && dstSlot < len(c.refCount) {
c.refCount[dstSlot]--
if c.refCount[dstSlot] <= 0 {
c.refCount[dstSlot] = 0
c.freeSlot(dstSlot)
}
}
delete(c.slotForSeq, dstSeq)
}
srcSlot, ok := c.slotForSeq[srcSeq]
if !ok {
// src may not have a slot yet; dst will allocate on demand
return
}
// Bounds check before incrementing
if srcSlot >= 0 && srcSlot < len(c.refCount) {
c.slotForSeq[dstSeq] = srcSlot
c.refCount[srcSlot]++
}
}
func (c *HybridCache) CanResume(seq int, pos int32) bool {
return c.kv.CanResume(seq, pos)
}
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
return err
}
// For recurrent state, any removal invalidates the state because
// the state at position N depends on all previous positions.
// Drop the slot mapping so it resets on next use.
slot, ok := c.slotForSeq[seq]
if !ok {
return nil
}
// Bounds check
if slot < 0 || slot >= len(c.refCount) {
delete(c.slotForSeq, seq)
return nil
}
c.refCount[slot]--
if c.refCount[slot] <= 0 {
c.refCount[slot] = 0
c.freeSlot(slot)
}
delete(c.slotForSeq, seq)
return nil
}
func (c *HybridCache) slotsTensor() ml.Tensor {
return c.curSlotsInput
}
func (c *HybridCache) seqTokens() int {
return c.curSeqTokens
}
func (c *HybridCache) numSeqs() int {
return len(c.curSeqs)
}
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
if buf, ok := c.convStates[layer]; ok {
return buf
}
if _, ok := c.convCtxs[layer]; !ok {
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
}
buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
c.convStates[layer] = buf
return buf
}
// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
// Returns an error if copy-on-write allocation fails.
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
if !c.writableEnsured {
needsWritable := false
for _, seq := range c.curSeqs {
slot, ok := c.slotForSeq[seq]
if !ok {
continue
}
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
needsWritable = true
break
}
}
if needsWritable {
if err := c.EnsureWritable(ctx); err != nil {
c.writableError = err
}
}
c.writableEnsured = true
}
if c.writableError != nil {
return nil, c.writableError
}
buf := c.convBuffer(ctx, layer)
cur := buf.Rows(ctx, c.slotsTensor())
return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
}
// UpdateConvState writes a new conv state for current batch sequences.
// newState must have shape [dConv, hiddenSize, nSeqs].
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
buf := c.convBuffer(ctx, layer)
src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
// SetRows requires F32 source
srcF32 := src.Cast(ctx, ml.DTypeF32)
ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
}
// IsSupportedForBatch returns true if the current batch layout supports shortconv.
func (c *HybridCache) IsSupportedForBatch() bool {
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}
// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
return slices.Clone(c.curSeqs)
}

View File

@@ -1,444 +0,0 @@
package lfm2
import (
"testing"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
// TestHybridCache tests verify the slot management logic of HybridCache.
// These tests focus on the recurrent state slot allocation, reference counting,
// and copy-on-write semantics without requiring a full ML backend.
// createSlotOnlyCache creates a HybridCache with only the slot management
// fields initialized. Used to test slot logic in isolation.
func createSlotOnlyCache(maxSequences int) *HybridCache {
return &HybridCache{
hiddenSize: 256,
dConv: 3,
maxSequences: maxSequences,
refCount: make([]int, maxSequences),
freeSlots: initFreeSlots(maxSequences),
slotForSeq: make(map[int]int),
convCtxs: make(map[int]ml.Context),
convStates: make(map[int]ml.Tensor),
}
}
func initFreeSlots(n int) []int {
slots := make([]int, 0, n)
for i := n - 1; i >= 0; i-- {
slots = append(slots, i)
}
return slots
}
func TestHybridCache_SlotAllocation(t *testing.T) {
cache := createSlotOnlyCache(4)
// Verify initial state
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
}
// Allocate all slots
for range 4 {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.refCount[slot] = 1
}
// Should be full now
if len(cache.freeSlots) != 0 {
t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
}
// Trying to allocate another should fail
_, err := cache.allocSlot()
if err != kvcache.ErrKvCacheFull {
t.Errorf("expected ErrKvCacheFull, got %v", err)
}
}
func TestHybridCache_SlotReuse(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate a slot
slot1, _ := cache.allocSlot()
cache.refCount[slot1] = 1
// Free it
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
// Allocate again - should get the same slot back (LIFO)
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
}
}
func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Simulate sharing slot with seq 2 (copy-on-write style)
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Should share the same slot
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Ref count should be 2
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share with seq 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Unshare seq 2
cache.refCount[slot1]--
delete(cache.slotForSeq, 2)
// Ref count should be back to 1
if cache.refCount[slot1] != 1 {
t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
}
// Seq 2 should no longer have a slot
if _, ok := cache.slotForSeq[2]; ok {
t.Error("seq 2 should not have a slot after unshare")
}
}
func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
cache := createSlotOnlyCache(4)
initialFreeSlots := len(cache.freeSlots)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot when refCount drops to 0
cache.refCount[slot1]--
if cache.refCount[slot1] <= 0 {
cache.refCount[slot1] = 0
cache.freeSlot(slot1)
}
delete(cache.slotForSeq, 1)
// Slot should be freed
if len(cache.freeSlots) != initialFreeSlots {
t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
}
// Ref count should be 0
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_SlotOverwrite(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slots for seq 1 and seq 2
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
slot2, _ := cache.allocSlot()
cache.slotForSeq[2] = slot2
cache.refCount[slot2] = 1
initialFreeSlots := len(cache.freeSlots)
// Simulate overwriting seq 2's slot with slot1 (sharing)
// First free the old slot
cache.refCount[slot2]--
if cache.refCount[slot2] <= 0 {
cache.refCount[slot2] = 0
cache.freeSlot(slot2)
}
// Then share slot1
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Seq 2 should now share slot1
if cache.slotForSeq[2] != slot1 {
t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
}
// Old slot2 should be freed
if len(cache.freeSlots) != initialFreeSlots+1 {
t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
}
}
func TestHybridCache_BoundsChecking(t *testing.T) {
cache := createSlotOnlyCache(4)
// Test freeing invalid slot (should not panic)
cache.freeSlot(-1)
cache.freeSlot(100) // out of bounds
// freeSlot does bounds checking, so invalid slots should be ignored
if len(cache.freeSlots) != 4 {
t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
}
}
func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
cache := createSlotOnlyCache(8)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Fork to seq 2, 3, 4 (all share slot1)
for _, seq := range []int{2, 3, 4} {
cache.slotForSeq[seq] = slot1
cache.refCount[slot1]++
}
// Ref count should be 4
if cache.refCount[slot1] != 4 {
t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
}
// Remove seq 2, 3
for _, seq := range []int{2, 3} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 2 {
t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
}
// Slot should still be allocated (not in free list)
found := false
for _, s := range cache.freeSlots {
if s == slot1 {
found = true
break
}
}
if found {
t.Error("slot1 should not be in free list yet")
}
// Remove remaining sequences
for _, seq := range []int{1, 4} {
delete(cache.slotForSeq, seq)
cache.refCount[slot1]--
}
if cache.refCount[slot1] != 0 {
t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_ChainedSharing(t *testing.T) {
cache := createSlotOnlyCache(8)
// Create seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Share 1 -> 2
cache.slotForSeq[2] = slot1
cache.refCount[slot1]++
// Share 2 -> 3 (should still share slot1)
cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
cache.refCount[slot1]++
// All should share slot1
if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
t.Error("all sequences should share slot1")
}
if cache.refCount[slot1] != 3 {
t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
}
}
func TestHybridCache_CacheParameters(t *testing.T) {
cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
if cache.hiddenSize != 512 {
t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
}
if cache.dConv != 5 {
t.Errorf("expected dConv 5, got %d", cache.dConv)
}
}
func TestHybridCache_NumSeqs(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially no sequences
if cache.numSeqs() != 0 {
t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
}
// Manually set up current batch state
cache.curSeqs = []int{1, 2, 3}
if cache.numSeqs() != 3 {
t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
}
}
func TestHybridCache_SeqTokens(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially 0
if cache.seqTokens() != 0 {
t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
}
// Manually set up current batch state
cache.curSeqTokens = 16
if cache.seqTokens() != 16 {
t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
}
}
// Test that Seqs returns a clone of curSeqs
func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
cache := createSlotOnlyCache(4)
cache.curSeqs = []int{1, 2, 3}
seqs := cache.Seqs()
// Modify returned slice
seqs[0] = 999
// Original should be unchanged
if cache.curSeqs[0] != 1 {
t.Error("Seqs should return a clone, not the original slice")
}
}
func TestHybridCache_IsSupportedForBatch(t *testing.T) {
cache := createSlotOnlyCache(4)
// Initially not supported (no batch set up)
if cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be false initially")
}
// Set up a valid batch
cache.curSeqTokens = 1
cache.curSeqs = []int{1}
if !cache.IsSupportedForBatch() {
t.Error("expected IsSupportedForBatch to be true with valid batch")
}
}
func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
cache := createSlotOnlyCache(4)
// zeroConvSlots should handle empty slots without panicking
cache.zeroConvSlots(nil, nil)
cache.zeroConvSlots(nil, []int{})
// zeroConvSlots should handle empty convStates without panicking
cache.zeroConvSlots(nil, []int{0, 1, 2})
}
func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
cache := createSlotOnlyCache(4)
// Allocate slot for seq 1
slot1, _ := cache.allocSlot()
cache.slotForSeq[1] = slot1
cache.refCount[slot1] = 1
// Free the slot (simulating sequence removal)
cache.refCount[slot1]--
cache.freeSlot(slot1)
delete(cache.slotForSeq, 1)
// Verify slot is in free list
if len(cache.freeSlots) != 4 {
t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
}
// Allocate for new seq 2 - should get recycled slot
slot2, _ := cache.allocSlot()
if slot2 != slot1 {
t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
}
// This recycled slot would need zeroing in the real implementation
// The actual zeroing is tested via integration tests since it requires ML context
}
func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
cache := createSlotOnlyCache(4)
// Simulate the slot allocation flow from StartForward
// When a sequence doesn't have a slot, it gets allocated and tracked as "new"
newSlots := []int{}
// Seq 1 doesn't have a slot - allocate and track
seq := 1
if _, ok := cache.slotForSeq[seq]; !ok {
slot, err := cache.allocSlot()
if err != nil {
t.Fatalf("allocSlot failed: %v", err)
}
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots = append(newSlots, slot)
}
// Verify newSlots contains the allocated slot
if len(newSlots) != 1 {
t.Errorf("expected 1 new slot, got %d", len(newSlots))
}
// Seq 1 already has a slot - should NOT be tracked as new
newSlots2 := []int{}
if _, ok := cache.slotForSeq[seq]; !ok {
slot, _ := cache.allocSlot()
cache.slotForSeq[seq] = slot
cache.refCount[slot] = 1
newSlots2 = append(newSlots2, slot)
}
// Verify no new slots for existing sequence
if len(newSlots2) != 0 {
t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
}
}

View File

@@ -1,253 +0,0 @@
package lfm2
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeType string
originalContextLength int
// per-layer head counts (LFM2 alternates attention and recurrent layers)
numHeadsByLayer []int
numKVHeadsByLayer []int
}
func (o Options) headDimValue() int {
// Head dim is shared across layers; fall back to first attention layer head count.
for _, h := range o.numHeadsByLayer {
if h > 0 {
return cmp.Or(o.headDim, o.hiddenSize/h)
}
}
return cmp.Or(o.headDim, o.hiddenSize)
}
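// Worked example (assumed values): with hiddenSize=2048, no explicit headDim in
// the config, and per-layer head counts [32, 0, 32, ...], the first non-zero
// count (32) is used, so the fallback head dimension is 2048/32 = 64. An
// explicit headDim, when present, wins via cmp.Or.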
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
opts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
)
}
headCount := 1
for _, h := range o.numHeadsByLayer {
if h > 0 {
headCount = h
break
}
}
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
}
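// Worked example (assumed ropeScale=4.0 with ropeType "yarn"):
// attnFactor = 1/(1 + 0.1*ln(4)) ≈ 1/1.139 ≈ 0.878, which is passed to the
// rope options through rope.WithAttentionFactor.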
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 {
return nil, model.ErrUnsupportedModel
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "default":
// use default BPE pretokenizer
default:
// llama-bpe style (default for LFM2)
pretokenizers = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(&vocabulary, pretokenizers...),
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeType: c.String("rope.scaling.type"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
type headCounts interface {
HeadCount() []uint64
HeadCountKV() []uint64
}
hc, ok := c.(headCounts)
if !ok {
return nil, model.ErrUnsupportedModel
}
headCount := hc.HeadCount()
headCountKV := hc.HeadCountKV()
m.numHeadsByLayer = make([]int, len(m.Layers))
m.numKVHeadsByLayer = make([]int, len(m.Layers))
for i := range m.Layers {
m.numHeadsByLayer[i] = int(headCount[i])
m.numKVHeadsByLayer[i] = int(headCountKV[i])
if m.numKVHeadsByLayer[i] == 0 {
m.Layers[i].Operator = &ShortConv{}
} else {
m.Layers[i].Operator = &Attention{}
}
}
lCache := int(c.Uint("shortconv.l_cache"))
dConv := max(0, lCache-1)
m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
return &m, nil
}
type Operator interface {
Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := opts.headDimValue()
numHeads := opts.numHeadsByLayer[layer]
numKVHeads := opts.numKVHeadsByLayer[layer]
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, numHeads, batchSize)
key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Operator Operator
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("lfm2", New)
}

View File

@@ -1,50 +0,0 @@
package lfm2
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type shortConvKernel struct {
Weight ml.Tensor `gguf:"weight"`
}
// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
// state stored in the HybridCache.
type ShortConv struct {
Conv *shortConvKernel `gguf:"shortconv.conv"`
InProj *nn.Linear `gguf:"shortconv.in_proj"`
OutProj *nn.Linear `gguf:"shortconv.out_proj"`
}
func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
nSeqs := cache.numSeqs()
seqTokens := cache.seqTokens()
hiddenSize := hiddenStates.Dim(0)
if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
panic("lfm2: unsupported batch layout for shortconv")
}
bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
elementSize := bcx.Stride(0)
b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
state, err := cache.ConvState(ctx, layer)
if err != nil {
panic("lfm2: failed to get conv state: " + err.Error())
}
sx := state.Concat(ctx, bx, 0)
convOut := sx.SSMConv(ctx, sc.Conv.Weight)
y := c.Mul(ctx, convOut)
dConv := sx.Dim(0) - seqTokens
cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
}
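// Shape walkthrough (illustrative, assuming hiddenSize=256, seqTokens=4, nSeqs=2, dConv=3):
//
//	bcx:     [768, 4, 2]  in_proj output, 3*hiddenSize features per token
//	b, c, x: [256, 4, 2]  views over the three slabs of bcx
//	bx:      [4, 256, 2]  b*x with time moved to the leading dimension
//	sx:      [7, 256, 2]  cached state (3 steps) concatenated with bx (4 steps)
//
// The last dConv=3 time steps of sx are written back to the cache as the next
// recurrent state, and the gated conv output is projected back to hiddenSize.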

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
var processor model.TextProcessor
vocabulary := model.Vocabulary{
var processor tokenizers.Tokenizer
vocabulary := tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
processor = tokenizers.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
processor = tokenizers.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
var _ model.TextProcessor = (*Model)(nil)
var _ tokenizers.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,9 +7,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"

View File

@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,8 +179,8 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
tokenizer := tokenizers.NewWordPiece(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -219,8 +220,8 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
TextProcessor: processor,
Layers: layers,
Tokenizer: tokenizer,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,7 +45,7 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
vocabulary := tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -58,14 +59,14 @@ func New(c fs.Config) (model.Model, error) {
),
}
processor := model.NewBytePairEncoding(
tokenizer := tokenizers.NewBytePairEncoding(
&vocabulary,
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Tokenizer: tokenizer,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type embedModel struct {
model.Base
model.BytePairEncoding
tokenizers.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
model.BytePairEncoding
tokenizers.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
BytePairEncoding: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizers"
)
type Model struct {
model.Base
model.TextProcessor
tokenizers.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Tokenizer: tokenizers.NewBytePairEncoding(
&tokenizers.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),

View File

@@ -1,410 +0,0 @@
package parsers
import (
"context"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type glm46ParserState int
const (
glm46ParserState_LookingForThinkingOpen glm46ParserState = iota
glm46ParserState_ThinkingStartedEatingWhitespace
glm46ParserState_CollectingThinking
glm46ParserState_ThinkingDoneEatingWhitespace
glm46ParserState_CollectingContent
glm46ParserState_ToolStartedEatingWhitespace
glm46ParserState_CollectingToolContent
)
const (
glm46ThinkingOpenTag = "<think>"
glm46ThinkingCloseTag = "</think>"
glm46ToolOpenTag = "<tool_call>"
glm46ToolCloseTag = "</tool_call>"
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools
}
type glm46Event interface {
isGLM46Event()
}
type glm46EventContent struct {
content string
}
func (glm46EventContent) isGLM46Event() {}
type glm46EventRawToolCall struct {
raw string
}
func (glm46EventRawToolCall) isGLM46Event() {}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseGLM46ToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4.6 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *GLM46Parser) eatLeadingWhitespaceAndTransitionTo(nextState glm46ParserState) ([]glm46Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// glm46SplitAtTag splits the buffer at the given tag. It returns the content
// before the tag (trimmed of trailing whitespace) and the content after it
// (optionally trimmed of leading whitespace), leaving the buffer holding only
// the content after the tag.
func glm46SplitAtTag(p *GLM46Parser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
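// Illustrative example (assumed buffer contents): with the buffer holding
// "thinking done </think>  answer", tag "</think>", and trimAfter=true, this
// returns before="thinking done" and after="answer", and the buffer is left
// containing only "answer".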
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46ParserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, glm46ThinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, glm46ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = glm46ParserState_ThinkingStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(glm46ThinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = glm46ParserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case glm46ParserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingThinking)
case glm46ParserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, glm46ThinkingCloseTag) {
thinking, remaining := glm46SplitAtTag(p, glm46ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, glm46EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = glm46ParserState_ThinkingDoneEatingWhitespace
} else {
p.state = glm46ParserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, glm46ThinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingContent)
case glm46ParserState_CollectingContent:
if strings.Contains(p.buffer.String(), glm46ToolOpenTag) {
before, after := glm46SplitAtTag(p, glm46ToolOpenTag, true)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
if after == "" {
p.state = glm46ParserState_ToolStartedEatingWhitespace
} else {
p.state = glm46ParserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), glm46ToolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case glm46ParserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(glm46ParserState_CollectingToolContent)
case glm46ParserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, glm46ToolCloseTag) {
toolContent, _ := glm46SplitAtTag(p, glm46ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm46 tool call closing tag found but no content before it")
}
events = append(events, glm46EventRawToolCall{raw: toolContent})
p.state = glm46ParserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// GLMToolCallXML represents the structure of a GLM-4.6 tool call for XML parsing
type GLMToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeGLM46Content escapes XML special characters in text content while preserving arg_key/arg_value tags
func escapeGLM46Content(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
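// Illustrative example (assumed input):
//
//	escapeGLM46Content("<arg_key>a<b</arg_key>") == "<arg_key>a&lt;b</arg_key>"
//
// The known arg_key tags pass through unchanged, while the stray '<' inside the
// text content is escaped so the wrapped string stays parseable as XML.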
func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped special characters in text content.
// We need to escape text between tags, but not the tags themselves.
escaped := escapeGLM46Content(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}

View File

@@ -1,862 +0,0 @@
package parsers
import (
"encoding/xml"
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM46ParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []glm46Event
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "leading whitespace before think tag",
steps: []step{
{
input: " \n\t ",
wantEvents: []glm46Event{},
},
{
input: "<think>thinking</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking"}},
},
},
},
{
desc: "think tag with whitespace inside",
steps: []step{
{
input: "<think> \n thinking content \n </think>regular content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
glm46EventContent{content: "regular content"},
},
},
},
},
{
desc: "tool call with leading whitespace after opening tag",
steps: []step{
{
input: "<think></think><tool_call> \n test \n </tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "simple thinking then content",
steps: []step{
{
input: "<think>I am thinking</think>Now I respond",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "I am thinking"},
glm46EventContent{content: "Now I respond"},
},
},
},
},
{
desc: "streamed thinking content",
steps: []step{
{
input: "<think>hello",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "hello"}},
},
{
input: " world",
wantEvents: []glm46Event{glm46EventThinkingContent{content: " world"}},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "<think>Let me call a tool</think>here is text<tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "Let me call a tool"},
glm46EventContent{content: "here is text"},
},
},
{
input: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>\n</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "function_name\n<arg_key>param</arg_key>\n<arg_value>value</arg_value>"},
},
},
},
},
{
desc: "tool call with content after",
steps: []step{
{
input: "<think>thinking</think><tool_call>test</tool_call>after tool",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after tool"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call is trimmed",
steps: []step{
{
input: "<think>thinking</think>content\n \t <tool_call>test</tool_call>",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
glm46EventRawToolCall{raw: "test"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content is trimmed",
steps: []step{
{
input: "<think>think</think><tool_call>test</tool_call>\n\t after",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking close tag",
steps: []step{
{
input: "<think>thinking content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "thinking content"}},
},
{
input: "ink>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "split thinking open tag",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nk>content</think>",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
},
},
{
desc: "split tool open tag",
steps: []step{
{
input: "<think>think</think>content<tool",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "think"}, glm46EventContent{content: "content"}},
},
{
input: "_call>inside",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "inside"},
},
},
},
},
{
desc: "partial thinking close tag fakeout",
steps: []step{
{
input: "<think>content</th",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "content"}},
},
{
input: "ought more",
wantEvents: []glm46Event{glm46EventThinkingContent{content: "</thought more"}},
},
},
},
{
desc: "partial thinking open tag fakeout",
steps: []step{
{
input: " <thi",
wantEvents: []glm46Event{},
},
{
input: "nking is fun",
wantEvents: []glm46Event{
glm46EventContent{content: " <thinking is fun"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "<think></think>content\n<tool",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
{
input: " fakeout",
wantEvents: []glm46Event{
glm46EventContent{content: "\n<tool fakeout"},
},
},
},
},
{
desc: "partial tool close tag fakeout",
steps: []step{
{
input: "<think></think><tool_call>content</tool",
wantEvents: []glm46Event{},
},
{
input: " fakeout",
wantEvents: []glm46Event{},
},
{
input: "</tool_call>",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "content</tool fakeout"},
},
},
},
},
{
desc: "empty thinking tag",
steps: []step{
{
input: "<think></think>content here",
wantEvents: []glm46Event{
glm46EventContent{content: "content here"},
},
},
},
},
{
desc: "multiple tool calls in sequence",
steps: []step{
{
input: "<think>think</think><tool_call>first</tool_call>between<tool_call>second</tool_call>end",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "think"},
glm46EventRawToolCall{raw: "first"},
glm46EventContent{content: "between"},
glm46EventRawToolCall{raw: "second"},
glm46EventContent{content: "end"},
},
},
},
},
{
desc: "no thinking tag - direct to content",
steps: []step{
{
input: "just content here",
wantEvents: []glm46Event{
glm46EventContent{content: "just content here"},
},
},
},
},
{
desc: "no thinking tag - skip to content then tool call",
steps: []step{
{
input: "Here's the answer:<tool_call>test</tool_call>done",
wantEvents: []glm46Event{
glm46EventContent{content: "Here's the answer:"},
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "done"},
},
},
},
},
{
desc: "no thinking tag - whitespace preserved when no tags",
steps: []step{
{
input: " \n content with leading whitespace",
wantEvents: []glm46Event{
glm46EventContent{content: " \n content with leading whitespace"},
},
},
},
},
{
desc: "whitespace after think close tag gets eaten",
steps: []step{
{
input: "<think>thinking</think> \n\t content",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "whitespace after tool_call close tag gets eaten",
steps: []step{
{
input: "<think></think><tool_call>test</tool_call> \n\t content",
wantEvents: []glm46Event{
glm46EventRawToolCall{raw: "test"},
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace (single chunk)",
steps: []step{
{
input: "<think>thinking content ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking content"},
},
},
{
input: "</think>after",
wantEvents: []glm46Event{
glm46EventContent{content: "after"},
},
},
},
},
{
desc: "thinking content withholds trailing whitespace with newlines",
steps: []step{
{
input: "<think>thinking\n\n ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "</think>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
{
desc: "thinking content trailing whitespace emitted when more content arrives",
steps: []step{
{
input: "<think>thinking ",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "more thinking",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: " more thinking"},
},
},
{
input: "</think>",
wantEvents: []glm46Event{},
},
},
},
{
desc: "thinking content withholds trailing whitespace before partial close tag",
steps: []step{
{
input: "<think>thinking </th",
wantEvents: []glm46Event{
glm46EventThinkingContent{content: "thinking"},
},
},
{
input: "ink>content",
wantEvents: []glm46Event{
glm46EventContent{content: "content"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := GLM46Parser{}
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
// TestGLMToolCallXMLOrderPreservation verifies that xml.Unmarshal preserves
// document order when collecting multiple elements with the same tag name into slices.
// This is a critical assumption for the GLM-4.6 parser's struct-based approach.
func TestGLMToolCallXMLOrderPreservation(t *testing.T) {
testCases := []struct {
name string
xml string
wantKeys []string
wantValues []string
}{
{
name: "alternating keys and values",
xml: `<tool_call>
function_name
<arg_key>first</arg_key>
<arg_value>A</arg_value>
<arg_key>second</arg_key>
<arg_value>B</arg_value>
<arg_key>third</arg_key>
<arg_value>C</arg_value>
</tool_call>`,
wantKeys: []string{"first", "second", "third"},
wantValues: []string{"A", "B", "C"},
},
{
name: "all keys then all values",
xml: `<tool_call>
function_name
<arg_key>key1</arg_key>
<arg_key>key2</arg_key>
<arg_key>key3</arg_key>
<arg_value>val1</arg_value>
<arg_value>val2</arg_value>
<arg_value>val3</arg_value>
</tool_call>`,
wantKeys: []string{"key1", "key2", "key3"},
wantValues: []string{"val1", "val2", "val3"},
},
{
name: "mixed grouping",
xml: `<tool_call>
function_name
<arg_key>a</arg_key>
<arg_value>1</arg_value>
<arg_key>b</arg_key>
<arg_key>c</arg_key>
<arg_value>2</arg_value>
<arg_value>3</arg_value>
</tool_call>`,
wantKeys: []string{"a", "b", "c"},
wantValues: []string{"1", "2", "3"},
},
{
name: "reverse order - all values then all keys",
xml: `<tool_call>
function_name
<arg_value>X</arg_value>
<arg_value>Y</arg_value>
<arg_value>Z</arg_value>
<arg_key>x</arg_key>
<arg_key>y</arg_key>
<arg_key>z</arg_key>
</tool_call>`,
wantKeys: []string{"x", "y", "z"},
wantValues: []string{"X", "Y", "Z"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var parsed GLMToolCallXML
err := xml.Unmarshal([]byte(tc.xml), &parsed)
if err != nil {
t.Fatalf("failed to unmarshal XML: %v", err)
}
if !reflect.DeepEqual(parsed.Keys, tc.wantKeys) {
t.Errorf("Keys order mismatch:\ngot: %v\nwant: %v", parsed.Keys, tc.wantKeys)
}
if !reflect.DeepEqual(parsed.Values, tc.wantValues) {
t.Errorf("Values order mismatch:\ngot: %v\nwant: %v", parsed.Values, tc.wantValues)
}
})
}
}
func TestGLM46ToolCallParsing(t *testing.T) {
type testCase struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
cases := []testCase{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `get-current-weather
<arg_key>location</arg_key>
<arg_value>New York, NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-current-weather",
Arguments: args(`{"location": "New York, NY", "unit": "celsius"}`),
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `calculate
<arg_key>x</arg_key>
<arg_value>3.14</arg_value>
<arg_key>y</arg_key>
<arg_value>42</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>items</arg_key>
<arg_value>["a", "b", "c"]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: args(`{"enabled": true, "items": ["a", "b", "c"], "x": 3.14, "y": 42}`),
},
},
},
{
name: "function name with whitespace",
tools: []api.Tool{},
rawToolCall: ` get-weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris"}`),
},
},
},
{
name: "values with special characters",
tools: []api.Tool{},
rawToolCall: `execute-command
<arg_key>command</arg_key>
<arg_value>ls && echo "done"</arg_value>
<arg_key>message</arg_key>
<arg_value>a < b and c > d</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute-command",
Arguments: args(`{"command": "ls && echo \"done\"", "message": "a < b and c > d"}`),
},
},
},
{
name: "unicode in function names and values",
tools: []api.Tool{},
rawToolCall: `获取天气
<arg_key>城市</arg_key>
<arg_value>北京</arg_value>
<arg_key>message</arg_key>
<arg_value>Hello! 你好! 🌟</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: args(`{"message": "Hello! 你好! 🌟", "城市": "北京"}`),
},
},
},
{
name: "empty value",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param1</arg_key>
<arg_value></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param1": ""}`),
},
},
},
{
name: "special chars in arg_key names",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param<1></arg_key>
<arg_value>value1</arg_value>
<arg_key>a&b</arg_key>
<arg_value>value2</arg_value>
<arg_key>x>y</arg_key>
<arg_value>value3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"a&b": "value2", "param<1>": "value1", "x>y": "value3"}`),
},
},
},
{
name: "multiple consecutive ampersands",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>test &&&& more</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "test &&&& more"}`),
},
},
},
{
name: "mixed special chars together",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value><>&<>&</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"param": "<>&<>&"}`),
},
},
},
{
name: "newlines and tabs in parameter values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>multiline</arg_key>
<arg_value>line1
indented line2
line3</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"multiline": "line1\n\tindented line2\nline3"}`),
},
},
},
{
name: "single and double quotes in values",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>quotes</arg_key>
<arg_value>She said "Hello's there!"</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"quotes": "She said \"Hello's there!\""}`),
},
},
},
{
name: "CDATA-like content that should be treated as text",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>cdata</arg_key>
<arg_value><![CDATA[not actual cdata]]></arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"cdata": "<![CDATA[not actual cdata]]>"}`),
},
},
},
{
name: "all special XML entities",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>entities</arg_key>
<arg_value>&lt;&gt;&amp;&apos;&quot;</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"entities": "&lt;&gt;&amp;&apos;&quot;"}`),
},
},
},
{
name: "order preservation with multiple parameters",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>first</arg_key>
<arg_value>value1</arg_value>
<arg_key>second</arg_key>
<arg_value>value2</arg_value>
<arg_key>third</arg_key>
<arg_value>value3</arg_value>
<arg_key>fourth</arg_key>
<arg_value>value4</arg_value>
<arg_key>fifth</arg_key>
<arg_value>value5</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
Arguments: args(`{"fifth": "value5", "first": "value1", "fourth": "value4", "second": "value2", "third": "value3"}`),
},
},
},
{
name: "order preservation with identical key names but different positions",
tools: []api.Tool{},
rawToolCall: `test-function
<arg_key>param</arg_key>
<arg_value>first occurrence</arg_value>
<arg_key>other</arg_key>
<arg_value>middle</arg_value>
<arg_key>param</arg_key>
<arg_value>second occurrence</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test-function",
// Later occurrence should overwrite earlier one
Arguments: args(`{"other": "middle", "param": "second occurrence"}`),
},
},
},
{
name: "array with mixed types",
tools: []api.Tool{
tool("process", map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `process
<arg_key>items</arg_key>
<arg_value>[1, "hello", true, null]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process",
Arguments: args(`{"items": [1, "hello", true, null]}`),
},
},
},
{
name: "empty array",
tools: []api.Tool{
tool("test", map[string]api.ToolProperty{
"tags": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `test
<arg_key>tags</arg_key>
<arg_value>[]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "test",
Arguments: args(`{"tags": []}`),
},
},
},
{
name: "anyOf array or string - with array of objects",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {AnyOf: []api.ToolProperty{{Type: api.PropertyType{"array"}}, {Type: api.PropertyType{"string"}}}},
}),
},
// <tool_call>TodoWrite
// <arg_key>todos</arg_key>
// <arg_value>[{"content": "Set up HTML file and basic structure", "id": "1", "priority": "high", "status": "pending"}, {"content": "Create 3D scene with Three.js", "id": "2", "priority": "high", "status": "pending"}, {"content": "Implement terrain generation with blocks", "id": "3", "priority": "high", "status": "pending"}, {"content": "Add player controls (movement, camera)", "id": "4", "priority": "high", "status": "pending"}, {"content": "Implement block placement/destruction", "id": "5", "priority": "medium", "status": "pending"}, {"content": "Add lighting and textures", "id": "6", "priority": "medium", "status": "pending"}, {"content": "Test and optimize performance", "id": "7", "priority": "low", "status": "pending"}]</arg_value>
// </tool_call>
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": [{"content": "task 1", "id": "1", "priority": "high", "status": "pending"}, {"content": "task 2", "id": "2", "priority": "low", "status": "completed"}]}`),
},
},
},
{
name: "anyOf array or string - with plain string",
tools: []api.Tool{
tool("TodoWrite", map[string]api.ToolProperty{
"todos": {Type: api.PropertyType{"array", "string"}},
}),
},
rawToolCall: `TodoWrite
<arg_key>todos</arg_key>
<arg_value>Error: could not load todos</arg_value>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "TodoWrite",
Arguments: args(`{"todos": "Error: could not load todos"}`),
},
},
},
}
for i, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
gotToolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: tc.rawToolCall}, tc.tools)
if err != nil {
t.Errorf("case %d (%s): %v", i, tc.name, err)
}
if !toolCallEqual(gotToolCall, tc.wantToolCall) {
t.Errorf("case %d (%s): got tool call %#v, want %#v", i, tc.name, gotToolCall, tc.wantToolCall)
}
})
}
}

View File

@@ -1,20 +0,0 @@
package parsers
import "github.com/ollama/ollama/api"
// GLM47Parser extends GLM46Parser with thinking-aware initialization.
// GLM-4.7's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type GLM47Parser struct {
GLM46Parser
}
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = glm46ParserState_CollectingThinking
}
return tools
}
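// Illustrative sketch, not part of the original file: a minimal driver for the
// initialization described above. The function name is hypothetical; it relies
// only on the declarations in this file and on the embedded GLM46Parser.Add.
func exampleGLM47ParserUsage() (content, thinking string, calls []api.ToolCall, err error) {
	p := &GLM47Parser{}
	// A nil thinkValue counts as thinking enabled, so the parser starts in
	// CollectingThinking and the first chunk is treated as thinking content.
	p.Init(nil, nil, nil)
	return p.Add("plan</think>Answer", true)
}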

View File

@@ -1,99 +0,0 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
func TestGLM47ParserAdd(t *testing.T) {
parser := GLM47Parser{}
parser.Init([]api.Tool{
tool("calculate", map[string]api.ToolProperty{
"count": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
}),
}, nil, nil)
// When thinking is enabled (thinkValue nil), the prompt ends with <think>,
// so the model output does NOT include the opening <think> tag.
content, thinking, calls, err := parser.Add("plan</think>Answer<tool_call>calculate<arg_key>count</arg_key><arg_value>3</arg_value><arg_key>enabled</arg_key><arg_value>true</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "plan" {
t.Fatalf("expected thinking 'plan', got %q", thinking)
}
if content != "Answer" {
t.Fatalf("expected content 'Answer', got %q", content)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
expectedArgs := args(`{"count": 3, "enabled": true}`)
if !toolCallEqual(api.ToolCall{Function: api.ToolCallFunction{Arguments: calls[0].Function.Arguments}}, api.ToolCall{Function: api.ToolCallFunction{Arguments: expectedArgs}}) {
t.Fatalf("expected args %#v, got %#v", expectedArgs.ToMap(), calls[0].Function.Arguments.ToMap())
}
}
func TestGLM47ParserNoThinkingContent(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
// When thinking is enabled but model has no thinking to output,
// it should output </think> immediately followed by content.
content, thinking, calls, err := parser.Add("</think>Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserThinkingDisabled(t *testing.T) {
parser := GLM47Parser{}
// When thinking is disabled, parser stays in LookingForThinkingOpen state
parser.Init(nil, nil, &api.ThinkValue{Value: false})
// Model outputs plain content (prompt ended with </think>)
content, thinking, calls, err := parser.Add("Plain answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if content != "Plain answer" {
t.Fatalf("expected content 'Plain answer', got %q", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestGLM47ParserToolCallEscaping(t *testing.T) {
toolCall, err := parseGLM46ToolCall(glm46EventRawToolCall{raw: `exec
<arg_key>expr</arg_key>
<arg_value>a < b && c > d</arg_value>`}, nil)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
expected := api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: args(`{"expr": "a < b && c > d"}`),
},
}
if !reflect.DeepEqual(toolCall, expected) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}

View File

@@ -1,498 +0,0 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type LFM2ParserState int
const (
LFM2CollectingThinking LFM2ParserState = iota
LFM2CollectingContent
LFM2CollectingToolCalls
)
const (
lfm2ThinkingOpenTag = "<think>"
lfm2ThinkingCloseTag = "</think>"
lfm2ToolCallStartTag = "<|tool_call_start|>"
lfm2ToolCallEndTag = "<|tool_call_end|>"
)
type LFM2Parser struct {
state LFM2ParserState
buffer strings.Builder
hasThinkingSupport bool
needsThinkingLeadingTrim bool // trim leading whitespace after <think> tag
needsContentLeadingTrim bool // trim leading whitespace after </think> tag
}
func (p *LFM2Parser) HasToolSupport() bool {
return true
}
func (p *LFM2Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
if !thinkingEnabled {
p.state = LFM2CollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = LFM2CollectingContent
return
}
p.state = LFM2CollectingThinking
p.needsThinkingLeadingTrim = true
}
func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, thinkValue)
return tools
}
type lfm2Event interface {
isLFM2Event()
}
type lfm2EventThinkingContent struct {
content string
}
type lfm2EventContent struct {
content string
}
type lfm2EventToolCall struct {
toolCall api.ToolCall
}
func (lfm2EventThinkingContent) isLFM2Event() {}
func (lfm2EventContent) isLFM2Event() {}
func (lfm2EventToolCall) isLFM2Event() {}
func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case lfm2EventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case lfm2EventThinkingContent:
thinkingSb.WriteString(event.content)
case lfm2EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *LFM2Parser) parseEvents() []lfm2Event {
var all []lfm2Event
keepLooping := true
for keepLooping {
var events []lfm2Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
var events []lfm2Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case LFM2CollectingThinking:
// Strip opening <think> tag if present
if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
bufStr = bufStr[len(lfm2ThinkingOpenTag):]
p.needsThinkingLeadingTrim = true
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Trim leading whitespace after <think> tag (may span multiple chunks)
if p.needsThinkingLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once non-whitespace content remains after trimming
if len(bufStr) > 0 {
p.needsThinkingLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingContent
p.needsThinkingLeadingTrim = false
// Set flag to trim any additional whitespace that may arrive in later chunks
p.needsContentLeadingTrim = len(remaining) == 0
if len(thinking) > 0 {
events = append(events, lfm2EventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise it's thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, lfm2EventThinkingContent{content: unambiguous})
}
return events, false
}
case LFM2CollectingContent:
// Trim leading whitespace after </think> tag (may span multiple chunks)
if p.needsContentLeadingTrim {
if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
bufStr = trimmed
p.buffer.Reset()
p.buffer.WriteString(bufStr)
}
// Clear flag once we have non-whitespace content
if len(bufStr) > 0 {
p.needsContentLeadingTrim = false
}
}
if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = LFM2CollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, lfm2EventContent{content: contentBefore})
}
return events, true
} else { // otherwise it's content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, lfm2EventContent{content: bufStr})
}
return events, false
}
case LFM2CollectingToolCalls:
// Look for complete tool call JSON between tags
if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
toolCallContent := bufStr[:idx]
if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
// Check if there's another tool call
if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
remaining = remaining[len(lfm2ToolCallStartTag):]
} else {
// No more tool calls, go back to content
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.state = LFM2CollectingContent
}
p.buffer.Reset()
p.buffer.WriteString(remaining)
for _, tc := range toolCalls {
events = append(events, lfm2EventToolCall{toolCall: tc})
}
return events, true
} else if err != nil {
slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
}
}
return events, false
}
return events, false
}
// parseToolCallsContent parses one or more tool calls from content
// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Try JSON format first: {"name": "func", "arguments": {...}}
var parsed struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
var args api.ToolCallFunctionArguments
if len(parsed.Arguments) > 0 {
if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
return nil, err
}
} else {
args = api.NewToolCallFunctionArguments()
}
return []api.ToolCall{{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: args,
},
}}, nil
}
// Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
return p.parsePythonStyleToolCalls(content)
}
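// Illustrative sketch, not part of the original file: the JSON form is tried
// first, so a payload like the one below never reaches the Python-style
// fallback further down. The function name is hypothetical.
func exampleLFM2JSONToolCall() ([]api.ToolCall, error) {
	p := &LFM2Parser{}
	return p.parseToolCallsContent(`{"name": "get_weather", "arguments": {"location": "Paris"}}`)
}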
// parsePythonStyleToolCalls parses one or more Python-style tool calls
// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
content = strings.TrimSpace(content)
// Strip outer brackets if present: [func(...)] -> func(...)
if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
content = content[1 : len(content)-1]
}
var toolCalls []api.ToolCall
// Parse multiple function calls separated by commas at the top level
for len(content) > 0 {
content = strings.TrimSpace(content)
if content == "" {
break
}
// Skip leading comma from previous iteration
if strings.HasPrefix(content, ",") {
content = strings.TrimSpace(content[1:])
if content == "" {
break
}
}
// Find function name
parenIdx := strings.Index(content, "(")
if parenIdx == -1 {
return nil, errors.New("invalid tool call: no opening parenthesis")
}
funcName := strings.TrimSpace(content[:parenIdx])
if funcName == "" {
return nil, errors.New("invalid tool call: empty function name")
}
// Find matching closing parenthesis
closeIdx := findMatchingParen(content, parenIdx)
if closeIdx == -1 {
return nil, errors.New("invalid tool call: no matching closing parenthesis")
}
argsStr := content[parenIdx+1 : closeIdx]
args := api.NewToolCallFunctionArguments()
if argsStr != "" {
if err := parsePythonArgs(argsStr, &args); err != nil {
return nil, err
}
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
})
// Move past this function call
content = content[closeIdx+1:]
}
if len(toolCalls) == 0 {
return nil, errors.New("no tool calls found")
}
return toolCalls, nil
}
// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx
// Returns -1 if not found. Handles nested parentheses and quoted strings.
func findMatchingParen(s string, openIdx int) int {
depth := 1
i := openIdx + 1
for i < len(s) && depth > 0 {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 {
return i
}
case '\'', '"':
// Skip quoted string
quote := s[i]
i++
for i < len(s) && s[i] != quote {
if s[i] == '\\' && i+1 < len(s) {
i++ // skip escaped char
}
i++
}
}
i++
}
return -1
}
// parseToolCallContent parses a single tool call (for backward compatibility with tests)
func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
calls, err := p.parseToolCallsContent(content)
if err != nil {
return api.ToolCall{}, err
}
if len(calls) == 0 {
return api.ToolCall{}, errors.New("no tool call found")
}
return calls[0], nil
}
// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
// Simple state machine to parse key='value' pairs
// Handles: command='ls', flag="-la", count=42, enabled=true
var key string
i := 0
for i < len(argsStr) {
// Skip whitespace
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
if i >= len(argsStr) {
break
}
// Parse key
keyStart := i
for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
i++
}
if i >= len(argsStr) || argsStr[i] != '=' {
return errors.New("invalid argument: expected '='")
}
key = strings.TrimSpace(argsStr[keyStart:i])
i++ // skip '='
// Skip whitespace after =
for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
i++
}
// Parse value
var value string
if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
// Quoted string
quote := argsStr[i]
i++
valueStart := i
for i < len(argsStr) && argsStr[i] != quote {
if argsStr[i] == '\\' && i+1 < len(argsStr) {
i += 2 // skip escaped char
} else {
i++
}
}
value = argsStr[valueStart:i]
if i < len(argsStr) {
i++ // skip closing quote
}
args.Set(key, value)
} else {
// Unquoted value (number, bool, etc.)
valueStart := i
for i < len(argsStr) && argsStr[i] != ',' {
i++
}
value = strings.TrimSpace(argsStr[valueStart:i])
// Try to parse as number or bool
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
args.Set(key, v)
} else if v, err := strconv.ParseFloat(value, 64); err == nil {
args.Set(key, v)
} else if value == "true" {
args.Set(key, true)
} else if value == "false" {
args.Set(key, false)
} else {
args.Set(key, value)
}
}
// Skip comma and whitespace
for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
i++
}
}
return nil
}
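// Illustrative sketch, not part of the original file: Python-style calls are
// split on top-level parentheses, and each key=value pair is coerced to an
// int, float, or bool where possible, otherwise kept as a string. The function
// name is hypothetical.
func exampleLFM2PythonToolCalls() ([]api.ToolCall, error) {
	p := &LFM2Parser{}
	return p.parsePythonStyleToolCalls(`[bash(command='ls -la'),retry(count=3, enabled=true)]`)
}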

View File

File diff suppressed because it is too large

View File

@@ -68,12 +68,6 @@ func ParserForName(name string) Parser {
return &Nemotron3NanoParser{}
case "functiongemma":
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
case "lfm2":
return &LFM2Parser{hasThinkingSupport: false}
case "lfm2-thinking":
return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}

View File

@@ -96,11 +96,3 @@ func testArgs(m map[string]any) api.ToolCallFunctionArguments {
}
return args
}
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}

View File

@@ -1,110 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type GLM46Renderer struct{}
func (r *GLM46Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>\n")
}
return sb.String(), nil
}

View File

@@ -1,223 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
skip string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
skip: "tool call ordering not guaranteed yet",
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo, Japan", "unit": "celsius"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Japan", "unit": "fahrenheit"}`),
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}`),
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skip != "" {
t.Skip(tt.skip)
}
renderer := &GLM46Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

View File

@@ -1,170 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// GLM47Renderer renders messages for GLM-4.7 models.
//
// GLM-4.7 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type GLM47Renderer struct{}
func (r *GLM47Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatGLM47ToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderGLM47ToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
func renderGLM47ToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
func formatGLM47ToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}
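// Illustrative sketch, not part of the original file: formatGLM47ToolJSON only
// re-spaces the compact JSON produced by json.Marshal, adding a space after
// every ':' and ',' that sits outside a string literal. The function name is
// hypothetical.
func exampleFormatGLM47ToolJSON() string {
	// With the assumed input below the result is {"a": 1, "b": "x:y"}; the
	// colon inside the string literal is left untouched.
	return formatGLM47ToolJSON([]byte(`{"a":1,"b":"x:y"}`))
}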

View File

@@ -1,191 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM47Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|user|>Hello<|assistant|><think>",
},
{
name: "thinking disabled",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "[gMASK]<sop><|user|>Hello<|assistant|></think>",
},
{
name: "system and user",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
expected: "[gMASK]<sop><|system|>You are helpful.<|user|>Hello<|assistant|><think>",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello there"},
{Role: "user", Content: "How are you?"},
},
expected: "[gMASK]<sop><|user|>Hi<|assistant|></think>Hello there<|user|>How are you?<|assistant|><think>",
},
{
name: "assistant with reasoning_content",
messages: []api.Message{
{Role: "user", Content: "Answer with reasoning."},
{Role: "assistant", Thinking: "Plan.", Content: "Done."},
},
expected: "[gMASK]<sop><|user|>Answer with reasoning.<|assistant|><think>Plan.</think>Done.<|assistant|><think>",
},
{
name: "tool call with empty content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo", "unit": "celsius"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|><think>",
},
{
name: "tool call with content",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Let me check",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "assistant", Content: "It is 22C."},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Weather?<|assistant|></think>Let me check<tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><|assistant|></think>It is 22C.<|assistant|><think>",
},
{
name: "multiple tool calls and responses",
messages: []api.Message{
{Role: "user", Content: "Compare weather"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Tokyo"}`),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args(`{"location": "Paris"}`),
},
},
},
},
{Role: "tool", Content: `{"temperature":22}`},
{Role: "tool", Content: `{"temperature":18}`},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: propsMap(`{"location": {"type": "string"}}`),
},
},
},
},
expected: "[gMASK]<sop><|system|>\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"description\": \"Get weather\", \"parameters\": {\"type\": \"object\", \"required\": [\"location\"], \"properties\": {\"location\": {\"type\": \"string\"}}}}}\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:\n<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call><|user|>Compare weather<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>Tokyo</arg_value></tool_call><tool_call>get_weather<arg_key>location</arg_key><arg_value>Paris</arg_value></tool_call><|observation|><tool_response>{\"temperature\":22}</tool_response><tool_response>{\"temperature\":18}</tool_response><|assistant|><think>",
},
{
name: "preserved thinking in multi-turn",
messages: []api.Message{
{Role: "user", Content: "Think step by step"},
{Role: "assistant", Thinking: "Let me think...", Content: "Here's my answer."},
{Role: "user", Content: "Continue"},
},
expected: "[gMASK]<sop><|user|>Think step by step<|assistant|><think>Let me think...</think>Here's my answer.<|user|>Continue<|assistant|><think>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &GLM47Renderer{}
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

View File

@@ -1,144 +0,0 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type LFM2Renderer struct {
IsThinking bool
}
func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
// Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
// Extract first system message if present (to combine with tools)
var firstSystemContent string
startIdx := 0
if len(messages) > 0 && messages[0].Role == "system" {
firstSystemContent = messages[0].Content
startIdx = 1
}
// Append tools to first system content
if len(tools) > 0 {
if firstSystemContent != "" {
firstSystemContent += "\n"
}
firstSystemContent += "List of tools: ["
for i, tool := range tools {
toolJSON, err := json.Marshal(tool)
if err != nil {
return "", err
}
firstSystemContent += string(toolJSON)
if i < len(tools)-1 {
firstSystemContent += ", "
}
}
firstSystemContent += "]"
}
// Output first system block if it has content
if firstSystemContent != "" {
sb.WriteString("<|im_start|>system\n")
sb.WriteString(firstSystemContent)
sb.WriteString("<|im_end|>\n")
}
// Find the index of the last assistant message for thinking stripping
lastAssistantIndex := -1
for i := len(messages) - 1; i >= startIdx; i-- {
if messages[i].Role == "assistant" {
lastAssistantIndex = i
break
}
}
// Track whether we need to add generation prompt
needsGenerationPrompt := len(messages) > 0
for i := startIdx; i < len(messages); i++ {
message := messages[i]
switch message.Role {
case "system":
// Additional system messages (after the first) are rendered normally
sb.WriteString("<|im_start|>system\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
// Check if this is the last assistant message
isLastAssistant := i == lastAssistantIndex
// Process content (may need thinking stripped)
content := message.Content
// Handle thinking tags in assistant content
keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
if strings.Contains(content, "</think>") {
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 1 {
if !isLastAssistant && !keepPastThinking {
// Strip thinking entirely for past assistant messages
content = strings.TrimSpace(parts[1])
} else {
// Preserve thinking but trim whitespace after </think>
content = parts[0] + "</think>" + strings.TrimLeft(parts[1], " \t\n\r")
}
}
}
if len(message.ToolCalls) > 0 {
// Assistant with tool calls - write content first (if any after stripping)
if content != "" {
sb.WriteString(content)
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("<|tool_call_start|>")
toolCallJSON := map[string]any{
"name": toolCall.Function.Name,
"arguments": toolCall.Function.Arguments,
}
callJSON, _ := json.Marshal(toolCallJSON)
sb.WriteString(string(callJSON))
sb.WriteString("<|tool_call_end|>")
}
} else {
sb.WriteString(content)
}
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
case "tool":
// Tool responses are rendered as plain messages per the chat template
sb.WriteString("<|im_start|>tool\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
needsGenerationPrompt = true
}
}
// Add generation prompt
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n")
// Note: Model is a "thinking-only" model - it will output <think> itself
// We don't add <think> tag to the prompt
}
return sb.String(), nil
}
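// Illustrative sketch, not part of the original file: with thinking disabled,
// <think> blocks are stripped from every past assistant turn except the last
// one, mirroring the keepPastThinking logic above. The function name is
// hypothetical.
func exampleLFM2RenderStripsPastThinking() (string, error) {
	r := &LFM2Renderer{IsThinking: true}
	msgs := []api.Message{
		{Role: "user", Content: "Q1"},
		{Role: "assistant", Content: "<think>reason1</think>A1"}, // thinking stripped (not the last assistant)
		{Role: "user", Content: "Q2"},
		{Role: "assistant", Content: "<think>reason2</think>A2"}, // thinking preserved (last assistant)
	}
	return r.Render(msgs, nil, &api.ThinkValue{Value: false})
}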

View File

@@ -1,427 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestLFM2Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple system messages rendered separately",
messages: []api.Message{
{Role: "system", Content: "First instruction."},
{Role: "system", Content: "Second instruction."},
{Role: "user", Content: "Hello!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multi-turn conversation",
messages: []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "The answer is 4."},
{Role: "user", Content: "Thanks!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "only system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
},
{
// When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
name: "user-assistant-user: last assistant preserves thinking",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reasoning</think>A1"},
{Role: "user", Content: "Q2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reasoning</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With two assistants, first is stripped (not last), second preserved (is last)
name: "multi-turn thinking: first stripped, second preserved",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
// With thinking enabled (keep_past_thinking=true), both preserved
name: "multi-turn thinking: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "Q1"},
{Role: "assistant", Content: "<think>reason1</think>A1"},
{Role: "user", Content: "Q2"},
{Role: "assistant", Content: "<think>reason2</think>A2"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n<think>reason1</think>A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n<think>reason2</think>A2<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "assistant with content and tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{Role: "assistant", Content: "Let me check."},
{Role: "tool", Content: "22C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "multiple tool calls",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "London",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions with system message",
messages: []api.Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "tools definitions without system message",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
}),
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
},
{
name: "multiple tools without system message",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get time",
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "user-tool sequence",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "tool", Content: "22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "full tool call cycle",
messages: []api.Message{
{Role: "user", Content: "Check weather"},
{Role: "assistant", Content: "Let me check"},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "It's 22C"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "unicode content",
messages: []api.Message{
{Role: "user", Content: "你好世界! مرحبا 🌍"},
{Role: "assistant", Content: "Hello! 👋"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "newlines in content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
{Role: "assistant", Content: "Response with\nmultiple\nlines"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
},
{
name: "empty assistant content",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
{Role: "user", Content: "OK"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
},
{
// Generation prompt does NOT include <think> - model outputs it
name: "generation prompt has no think tag",
messages: []api.Message{
{Role: "user", Content: "Think hard"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
},
{
// Interleaved: thinking before tool call - last assistant preserves thinking
name: "thinking before tool call (last assistant)",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>I need to check the weather</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>I need to check the weather</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tool calls - first has thinking stripped
name: "two assistants with tools: first thinking stripped",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Two assistants with tools - both preserved when thinking enabled
name: "two assistants with tools: both preserved when thinking enabled",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "<think>checking</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
{Role: "tool", Content: "22C"},
{Role: "assistant", Content: "<think>got result</think>It's 22C!"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<think>checking</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n<think>got result</think>It's 22C!<|im_end|>\n<|im_start|>assistant\n",
},
{
// Content before thinking before tool call
name: "content then thinking then tool call",
messages: []api.Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "Let me check.<think>Using weather API</think>",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{
"location": "Paris",
}),
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<think>Using weather API</think><|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
},
}
renderer := &LFM2Renderer{IsThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -80,12 +80,6 @@ func rendererForName(name string) Renderer {
return &Nemotron3NanoRenderer{}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
case "lfm2":
return &LFM2Renderer{IsThinking: false}
case "lfm2-thinking":
return &LFM2Renderer{IsThinking: true}
default:
return nil
}

View File

@@ -1,26 +1,6 @@
package renderers
import (
"encoding/json"
"github.com/ollama/ollama/api"
)
func args(s string) api.ToolCallFunctionArguments {
var result api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in args(): " + err.Error())
}
return result
}
func propsMap(s string) *api.ToolPropertiesMap {
var result api.ToolPropertiesMap
if err := json.Unmarshal([]byte(s), &result); err != nil {
panic("invalid JSON in propsMap(): " + err.Error())
}
return &result
}
import "github.com/ollama/ollama/api"
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {

View File

@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/tokenizers"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
func calculateLogprobs(logits []float32, selectedToken int32, topK int, tokenizer tokenizers.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
text, _ := textProcessor.Decode([]int32{int32(tokenID)})
text, _ := tokenizer.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
tokens, err := s.model.(tokenizers.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -766,7 +767,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
if s.model.(tokenizers.Tokenizer).Is(token, tokenizers.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -775,14 +776,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
piece, err := s.model.(tokenizers.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizers.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
@@ -873,7 +874,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(tokenizers.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return

View File

@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizers"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(model.Vocabulary().Values))
pieces := make([]string, len(model.Vocabulary().Values))
for i := range model.Vocabulary().Values {
pieces[i], _ = model.Decode([]int32{int32(i)})
func NewGrammarSampler(tokenizer tokenizers.Tokenizer, grammarStr string) (*GrammarSampler, error) {
vocabIds := make([]uint32, len(tokenizer.Vocabulary().Values))
pieces := make([]string, len(tokenizer.Vocabulary().Values))
for i := range tokenizer.Vocabulary().Values {
pieces[i], _ = tokenizer.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tokenizer.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -28,7 +28,6 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -91,7 +90,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}
oldManifest, _ := manifest.ParseNamedManifest(name)
oldManifest, _ := ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
@@ -124,9 +123,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
mf, mErr := manifest.ParseNamedManifest(fromName)
if mErr == nil && mf.Config.Digest != "" {
configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -343,7 +342,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := manifest.BlobsPath(files[fn])
blobPath, err := GetBlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -395,7 +394,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := manifest.BlobsPath(digest)
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -433,7 +432,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
layer, err := manifest.NewLayer(t, mediaType)
layer, err := NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -466,7 +465,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []manifest.Layer
var layers []Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -551,13 +550,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
for _, layer := range layers {
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
if err := WriteManifest(name, *configLayer, layers); err != nil {
return err
}
@@ -578,7 +577,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
blob, err := manifest.BlobsPath(layer.Digest)
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -600,7 +599,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
newLayer, err := manifest.NewLayer(temp, layer.MediaType)
newLayer, err := NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -620,7 +619,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML
fn(api.ProgressResponse{Status: "parsing GGUF"})
blobPath, err := manifest.BlobsPath(digest)
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
@@ -655,7 +654,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -666,8 +665,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
func removeLayer(layers []Layer, mediatype string) []Layer {
return slices.DeleteFunc(layers, func(layer Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -681,7 +680,7 @@ func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
})
}
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
func setTemplate(layers []Layer, t string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -691,7 +690,7 @@ func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
}
blob := strings.NewReader(t)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -700,11 +699,11 @@ func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
return layers, nil
}
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
func setSystem(layers []Layer, s string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -713,9 +712,9 @@ func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
return layers, nil
}
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
func setLicense(layers []Layer, l string) ([]Layer, error) {
blob := strings.NewReader(l)
layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -723,7 +722,7 @@ func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
return layers, nil
}
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -732,7 +731,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
continue
}
digestPath, err := manifest.BlobsPath(layer.Digest)
digestPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -766,7 +765,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -774,7 +773,7 @@ func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer,
return layers, nil
}
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -787,7 +786,7 @@ func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, er
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -795,7 +794,7 @@ func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, er
return layers, nil
}
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -806,7 +805,7 @@ func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifes
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}

View File

@@ -10,7 +10,6 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)
func TestConvertFromSafetensors(t *testing.T) {
@@ -18,7 +17,7 @@ func TestConvertFromSafetensors(t *testing.T) {
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}

View File

@@ -24,8 +24,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -458,7 +456,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
type downloadOpts struct {
n model.Name
mp ModelPath
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -467,10 +465,10 @@ type downloadOpts struct {
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
}
fp, err := manifest.BlobsPath(opts.digest)
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -494,8 +492,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.n.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -23,7 +24,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -274,22 +274,44 @@ func (m *Model) String() string {
return modelfile.String()
}
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}
f, err := os.Open(fp)
if err != nil {
return nil, "", err
}
defer f.Close()
sha256sum := sha256.New()
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
n := model.ParseName(name)
mf, err := manifest.ParseNamedManifest(n)
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
if err != nil {
return nil, err
}
m := &Model{
Name: n.String(),
ShortName: n.DisplayShortest(),
Digest: mf.Digest(),
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: template.DefaultTemplate,
}
if mf.Config.Digest != "" {
filename, err := manifest.BlobsPath(mf.Config.Digest)
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
@@ -300,29 +322,29 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
}
for _, layer := range mf.Layers {
filename, err := manifest.BlobsPath(layer.Digest)
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
m.ModelPath = filename
m.ParentModel = layer.From
model.ModelPath = filename
model.ParentModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
m.AdapterPaths = append(m.AdapterPaths, filename)
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
m.ProjectorPaths = append(m.ProjectorPaths, filename)
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -330,7 +352,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
m.Template, err = template.Parse(string(bts))
model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -340,7 +362,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
m.System = string(bts)
model.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -349,7 +371,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -359,7 +381,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -367,11 +389,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
m.License = append(m.License, string(bts))
model.License = append(model.License, string(bts))
}
}
return m, nil
return model, nil
}
func CopyModel(src, dst model.Name) error {
@@ -386,7 +408,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}
manifests, err := manifest.Path()
manifests, err := GetManifestPath()
if err != nil {
return err
}
@@ -415,7 +437,7 @@ func CopyModel(src, dst model.Name) error {
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := manifest.Manifests(true)
manifests, err := Manifests(true)
if err != nil {
return err
}
@@ -430,7 +452,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := manifest.BlobsPath(k)
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -446,7 +468,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := manifest.BlobsPath("")
p, err := GetBlobsPath("")
if err != nil {
return err
}
@@ -461,9 +483,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := manifest.BlobsPath(name)
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -488,30 +510,63 @@ func PruneLayers() error {
return nil
}
func PruneDirectory(path string) error {
info, err := os.Lstat(path)
if err != nil {
return err
}
if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
entries, err := os.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
return err
}
}
entries, err = os.ReadDir(path)
if err != nil {
return err
}
if len(entries) > 0 {
return nil
}
return os.Remove(path)
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
n := model.ParseName(name)
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if n.ProtocolScheme == "http" && !regOpts.Insecure {
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
mf, err := manifest.ParseNamedManifest(n)
manifest, _, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := manifest.PathForName(n)
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -519,7 +574,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -527,17 +582,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
for _, layer := range layers {
if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
@@ -556,44 +611,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
n := model.ParseName(name)
mp := ParseModelPath(name)
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
existingMf, err := manifest.ParseNamedManifest(n)
manifest, _, err := GetManifest(mp)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
for _, l := range existingMf.Layers {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
}
if n.ProtocolScheme == "http" && !regOpts.Insecure {
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})
mf, err := pullModelManifest(ctx, n, regOpts)
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []manifest.Layer
layers = append(layers, mf.Layers...)
if mf.Config.Digest != "" {
layers = append(layers, mf.Config)
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -603,7 +658,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
n: n,
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -622,7 +677,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
fp, err := manifest.BlobsPath(layer.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -637,16 +692,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, mf.Config.Digest)
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := manifest.PathForName(n)
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -673,9 +728,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []manifest.Layer) bool {
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
@@ -683,7 +738,7 @@ func hasTensorLayers(layers []manifest.Layer) bool {
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -692,12 +747,12 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}
}
destDir, err := manifest.BlobsPath("")
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := n.BaseURL()
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -729,7 +784,7 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: n.DisplayNamespaceModel(),
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -740,12 +795,12 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(mf)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := manifest.PathForName(n)
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
@@ -757,7 +812,7 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -767,12 +822,12 @@ func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
}
}
srcDir, err := manifest.BlobsPath("")
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := n.BaseURL()
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -809,13 +864,13 @@ func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: n.Tag,
Repository: n.DisplayNamespaceModel(),
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -825,7 +880,7 @@ func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptio
}
defer resp.Body.Close()
var m manifest.Manifest
var m Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -987,7 +1042,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := manifest.BlobsPath(digest)
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}

View File

@@ -1,4 +1,4 @@
package manifest
package server
import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
Status string `json:"-"`
status string
}
const (
@@ -22,7 +22,7 @@ const (
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := BlobsPath("")
blobs, err := GetBlobsPath("")
if err != nil {
return Layer{}, err
}
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := BlobsPath(digest)
blob, err := GetBlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
Status: fmt.Sprintf("%s %s", status, digest),
status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := BlobsPath(digest)
blob, err := GetBlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
Status: fmt.Sprintf("using existing layer %s", digest),
status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return nil, errors.New("opening layer with empty digest")
}
blob, err := BlobsPath(l.Digest)
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return nil, err
}
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
}
}
blob, err := BlobsPath(l.Digest)
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}

View File

@@ -1,9 +1,10 @@
package manifest
package server
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -32,38 +33,12 @@ func (m *Manifest) Size() (size int64) {
return
}
func (m *Manifest) Digest() string {
return m.digest
}
func (m *Manifest) FileInfo() os.FileInfo {
return m.fi
}
// ReadConfigJSON reads and unmarshals a config layer as JSON.
func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
blobPath, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
data, err := os.ReadFile(blobPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
}
return fmt.Errorf("config %q not found in manifest", configPath)
}
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := Path()
manifests, err := GetManifestPath()
if err != nil {
return err
}
@@ -95,11 +70,11 @@ func (m *Manifest) RemoveLayers() error {
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := BlobsPath(layer.Digest)
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); os.IsNotExist(err) {
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -114,7 +89,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
manifests, err := Path()
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
@@ -146,7 +121,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := Path()
manifests, err := GetManifestPath()
if err != nil {
return err
}
@@ -173,7 +148,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
manifests, err := Path()
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}

View File

@@ -1,4 +1,4 @@
package manifest
package server
import (
"encoding/json"

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -21,19 +20,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
manifest.Layer
Layer
*ggml.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := manifest.ParseNamedManifest(name)
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
m, err = manifest.ParseNamedManifest(name)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -42,7 +41,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
for _, layer := range m.Layers {
layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -51,7 +50,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := manifest.BlobsPath(layer.Digest)
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -82,12 +81,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})
if t.Parameters != nil {
@@ -96,7 +95,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}
layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}

server/modelpath.go (new file, 146 lines)
View File

@@ -0,0 +1,146 @@
package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrModelPathInvalid = errors.New("invalid model path")
)
func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// GetManifestPath returns the path to the manifest file for the given model path; it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(envconfig.Models(), "blobs", digest)
dirPath := filepath.Dir(path)
if digest == "" {
dirPath = path
}
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
}
return path, nil
}
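
Usage note (not part of the diff): a minimal sketch of how the helpers in the new server/modelpath.go compose when resolving a short model reference. The model name, digest, and test wiring below are hypothetical illustrations, not code from this change.

package server

import (
	"fmt"
	"strings"
	"testing"
)

// TestModelPathSketch is a hypothetical example showing how ParseModelPath,
// GetManifestPath and GetBlobsPath fit together; it is not part of this change.
func TestModelPathSketch(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	mp := ParseModelPath("llama3:8b")
	fmt.Println(mp.GetFullTagname())  // registry.ollama.ai/library/llama3:8b
	fmt.Println(mp.GetShortTagname()) // llama3:8b

	// The manifest path lives under $OLLAMA_MODELS/manifests/...; creating
	// the directory is left to the caller.
	if p, err := mp.GetManifestPath(); err != nil {
		t.Fatal(err)
	} else {
		fmt.Println(p)
	}

	// Blob paths normalize "sha256:" to "sha256-" and reject anything that
	// is not a full 64-hex-character digest.
	if p, err := GetBlobsPath("sha256:" + strings.Repeat("0", 64)); err != nil {
		t.Fatal(err)
	} else {
		fmt.Println(p) // .../blobs/sha256-000...000
	}
}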

server/modelpath_test.go (new file, 153 lines)
View File

@@ -0,0 +1,153 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
tempDir := t.TempDir()
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(tempDir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", tempDir)
got, err := GetBlobsPath(tc.digest)
require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -198,8 +198,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
name := t.Name
quantize := strings.HasSuffix(name, "weight")
// don't quantize vision encoder tensors (named with "v." prefix)
quantize = quantize && !strings.HasPrefix(name, "v.")
// don't quantize vision stuff
quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
@@ -219,9 +219,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// NOTE: can't use LLM_TN here because the layer number is not known
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
// do not quantize LFM2's shortconv kernel weights
quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
// do not quantize RWKV's time_mix_first tensors
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")

View File

@@ -39,7 +39,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
@@ -221,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
@@ -316,7 +321,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// expire the runner if unload is requested (empty prompt, keep alive is 0)
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -330,12 +335,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
@@ -975,7 +974,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := manifest.Manifests(true)
existing, err := Manifests(true)
if err != nil {
return zero, err
}
@@ -1019,7 +1018,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
m, err := manifest.ParseNamedManifest(n)
m, err := ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1081,7 +1080,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, model.Unqualified(name)
return nil, ErrModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
@@ -1113,7 +1112,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
@@ -1122,7 +1121,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
@@ -1136,7 +1135,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
mf, err := manifest.ParseNamedManifest(name)
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -1148,11 +1147,8 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: mf.FileInfo().ModTime(),
ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
ModelInfo: make(map[string]any),
}
if m.Config.RemoteHost != "" {
@@ -1215,7 +1211,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
@@ -1224,12 +1220,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
@@ -1286,7 +1282,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
}
func (s *Server) ListHandler(c *gin.Context) {
ms, err := manifest.Manifests(true)
ms, err := Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1317,8 +1313,8 @@ func (s *Server) ListHandler(c *gin.Context) {
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
Digest: m.Digest(),
ModifiedAt: m.FileInfo().ModTime(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1377,7 +1373,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := manifest.BlobsPath(c.Param("digest"))
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1393,7 +1389,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
p, err := manifest.BlobsPath(ib)
p, err := GetBlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1411,7 +1407,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
path, err := manifest.BlobsPath(c.Param("digest"))
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1429,7 +1425,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
layer, err := manifest.NewLayer(c.Request.Body, "")
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1629,7 +1625,7 @@ func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
blobsDir, err := manifest.BlobsPath("")
blobsDir, err := GetBlobsPath("")
if err != nil {
return err
}
@@ -1638,7 +1634,7 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
if _, err := manifest.Manifests(false); err != nil {
if _, err := Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
@@ -1646,12 +1642,12 @@ func Serve(ln net.Listener) error {
return err
}
manifestsPath, err := manifest.Path()
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := manifest.PruneDirectory(manifestsPath); err != nil {
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}

View File

@@ -25,7 +25,6 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -224,15 +223,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
manifest, err := ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if mf.Config.Digest == "" {
if manifest.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
configPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -94,13 +93,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err)
}
config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
if err := WriteManifest(n, config, []Layer{config}); err != nil {
t.Fatal(err)
}

View File

@@ -2101,95 +2101,3 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
})
}
func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)
var loadFnCalled bool
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
loadFnCalled = true
req.successCh <- &runnerRef{llama: &mockRunner{}}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) {
loadFnCalled = false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "",
KeepAlive: &api.Duration{Duration: 0},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.DoneReason != "unload" {
t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason)
}
if !resp.Done {
t.Error("expected done to be true")
}
if loadFnCalled {
t.Error("expected model NOT to be loaded for unload request, but loadFn was called")
}
})
}
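For reference, the behavior exercised by this test can also be triggered from the Go API client: a GenerateRequest with an empty prompt and a zero KeepAlive asks the server to unload the model instead of running inference. A minimal sketch, assuming the standard api client package and a hypothetical model name:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.GenerateRequest{
		Model:     "llama3", // hypothetical model name
		Prompt:    "",       // empty prompt: no inference is performed
		KeepAlive: &api.Duration{Duration: 0},
	}
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		// Per the test above, expect Done == true and DoneReason == "unload".
		fmt.Println(r.Done, r.DoneReason)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}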

View File

@@ -21,14 +21,12 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
type blobUpload struct {
manifest.Layer
Layer
Total int64
Completed atomic.Int64
@@ -53,7 +51,7 @@ const (
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := manifest.BlobsPath(b.Digest)
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
}
@@ -61,7 +59,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
if b.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
requestURL.RawQuery = values.Encode()
}
@@ -130,7 +128,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
p, err := manifest.BlobsPath(b.Digest)
p, err := GetBlobsPath(b.Digest)
if err != nil {
b.err = err
return
@@ -366,9 +364,9 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
@@ -390,8 +388,8 @@ func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *r
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := n.BaseURL()
requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err
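For context, the blob URL assembled above has the shape v2/<namespace>/<model>/blobs/<digest> under the registry host. A small stand-alone illustration with net/url; the host, repository, and digest values are placeholders:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	base := &url.URL{Scheme: "https", Host: "registry.ollama.ai"}
	// Mirrors requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
	blobURL := base.JoinPath("v2", "library/llama3", "blobs", "sha256:0123abcd")
	fmt.Println(blobURL.String())
	// expected: https://registry.ollama.ai/v2/library/llama3/blobs/sha256:0123abcd
}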

View File

@@ -1,8 +1,10 @@
package model
package tokenizers
import (
"cmp"
"fmt"
"iter"
"log/slog"
"slices"
"strings"
@@ -16,7 +18,7 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
var _ Tokenizer = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
@@ -243,6 +245,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
type lazyIdsString struct {
ids []int32
}
func (l lazyIdsString) LogValue() slog.Value {
return slog.AnyValue(fmt.Sprint(l.ids))
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
@@ -267,6 +277,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
logutil.Trace("decoded", "string", sb.String(), "from", ids)
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}
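The lazyIdsString helper added above follows the slog.LogValuer pattern: the token ids are only formatted when a handler actually emits the record, so trace logging stays cheap when disabled. A self-contained sketch of the same pattern with illustrative names:

package main

import (
	"fmt"
	"log/slog"
	"os"
)

// lazyInts defers fmt.Sprint of a slice until the record is actually handled.
type lazyInts struct{ ids []int32 }

// LogValue implements slog.LogValuer; it runs only when a handler resolves the value.
func (l lazyInts) LogValue() slog.Value {
	return slog.AnyValue(fmt.Sprint(l.ids))
}

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
	ids := []int32{72, 101, 108, 108, 111}
	// Below the handler's level: the record is dropped early and the ids are never formatted.
	logger.Debug("decoded", "from", lazyInts{ids: ids})
	// At or above the level: the handler resolves the value and formats the slice.
	logger.Info("decoded", "from", lazyInts{ids: ids})
}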

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
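With TextProcessor renamed to Tokenizer, callers can stay generic over the concrete encoding. A minimal sketch using only the Encode/Decode signatures shown above; the package and helper names here are illustrative, not part of the diff:

package tokenizerexample // illustrative package name

// encDec is the subset of the Tokenizer interface needed here; the concrete
// tokenizers above (BytePairEncoding, SentencePiece, WordPiece) all satisfy it.
type encDec interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode([]int32) (string, error)
}

// roundTrip encodes and then decodes a string with any tokenizer implementation.
func roundTrip(tok encDec, s string) (string, error) {
	ids, err := tok.Encode(s, true) // addSpecial = true
	if err != nil {
		return "", err
	}
	return tok.Decode(ids)
}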

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"log/slog"

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"testing"

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
// Decode implements TextProcessor.
// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
// Encode implements TextProcessor.
// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
// Is implements TextProcessor.
// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{

View File

@@ -1,4 +1,4 @@
package model
package tokenizers
import (
"slices"

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"path/filepath"
"strings"
)
@@ -36,25 +35,22 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"
const (
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
defaultProtocolScheme = "https"
defaultHost = "registry.ollama.ai"
defaultNamespace = "library"
defaultTag = "latest"
)
// DefaultName returns a name with the default values for the host, namespace,
// tag, and protocol scheme parts. The model and digest parts are empty.
// and tag parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
// - The default protocol scheme is ("https")
func DefaultName() Name {
return Name{
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
ProtocolScheme: defaultProtocolScheme,
Host: defaultHost,
Namespace: defaultNamespace,
Tag: defaultTag,
}
}
@@ -91,11 +87,10 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
Host string
Namespace string
Model string
Tag string
ProtocolScheme string
Host string
Namespace string
Model string
Tag string
}
// ParseName parses and assembles a Name from a name string. The
@@ -165,9 +160,7 @@ func ParseNameBare(s string) Name {
}
scheme, host, ok := strings.Cut(s, "://")
if ok {
n.ProtocolScheme = scheme
} else {
if !ok {
host = scheme
}
n.Host = host
@@ -196,13 +189,12 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// Merge merges the host, namespace, and tag parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
return a
}
@@ -313,23 +305,6 @@ func (n Name) EqualFold(o Name) bool {
strings.EqualFold(n.Tag, o.Tag)
}
// BaseURL returns the base URL for the registry.
func (n Name) BaseURL() *url.URL {
return &url.URL{
Scheme: n.ProtocolScheme,
Host: n.Host,
}
}
// DisplayNamespaceModel returns the namespace and model joined by "/".
func (n Name) DisplayNamespaceModel() string {
var b strings.Builder
b.WriteString(n.Namespace)
b.WriteByte('/')
b.WriteString(n.Model)
return b.String()
}
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:

View File

@@ -32,11 +32,10 @@ func TestParseNameParts(t *testing.T) {
{
in: "scheme://host:port/namespace/model:tag",
want: Name{
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
ProtocolScheme: "scheme",
Host: "host:port",
Namespace: "namespace",
Model: "model",
Tag: "tag",
},
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
},
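With the protocol scheme dropped from Name, a scheme prefix in the input is simply discarded and any unset parts are filled in by merging with DefaultName. A short sketch, assuming only the exported functions shown in this diff; the input string and expected output are illustrative:

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Parse a scheme-less name, then fill in missing parts from the defaults.
	n := model.ParseNameBare("example.com/library/mymodel")
	n = model.Merge(n, model.DefaultName())
	fmt.Println(n.Host, n.Namespace, n.Model, n.Tag)
	// expected: example.com library mymodel latest
}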

View File

@@ -12,8 +12,8 @@ import (
"fmt"
"io"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
@@ -103,7 +103,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := manifest.NewLayer(r, mediaType)
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
@@ -141,13 +141,13 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
}
// Create layer for quantized weight
weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -169,7 +169,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -186,7 +186,7 @@ func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quant
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -221,15 +221,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
}
// Create config layer blob
configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to manifest.Layer
manifestLayers := make([]manifest.Layer, 0, len(layers))
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
for _, l := range layers {
manifestLayers = append(manifestLayers, manifest.Layer{
serverLayers = append(serverLayers, server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
@@ -243,19 +243,19 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
if err != nil {
return err
}
manifestLayers = append(manifestLayers, modelfileLayers...)
serverLayers = append(serverLayers, modelfileLayers...)
}
return manifest.WriteManifest(name, configLayer, manifestLayers)
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
@@ -263,7 +263,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
}
if mf.System != "" {
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
@@ -271,7 +271,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
}
if mf.License != "" {
layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}

View File

@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream and dual-stream architectures:
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
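The caching discipline described above (recompute shallow layers on refresh steps, otherwise reuse) reduces to a small get-or-compute loop. A self-contained toy sketch of that control flow, with float64 values standing in for the real *mlx.Array outputs; all names here are illustrative:

package main

import "fmt"

// toyCache mirrors the Get/Set shape of StepCache with plain values.
type toyCache struct{ layers []*float64 }

func newToyCache(n int) *toyCache { return &toyCache{layers: make([]*float64, n)} }

func (c *toyCache) get(l int) *float64 {
	if l < len(c.layers) {
		return c.layers[l]
	}
	return nil
}

func (c *toyCache) set(l int, v float64) {
	if l < len(c.layers) {
		c.layers[l] = &v
	}
}

func main() {
	const steps, cachedLayers, interval = 6, 3, 3
	cache := newToyCache(cachedLayers)
	for step := 0; step < steps; step++ {
		refresh := step%interval == 0 // recompute shallow layers every `interval` steps
		for layer := 0; layer < cachedLayers; layer++ {
			if !refresh {
				if out := cache.get(layer); out != nil {
					continue // non-refresh step: reuse the cached shallow-layer output
				}
			}
			out := float64(step*10 + layer) // stand-in for the real layer forward pass
			cache.set(layer, out)
		}
		fmt.Printf("step %d refresh=%v\n", step, refresh)
	}
}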

View File

@@ -21,6 +21,8 @@ import (
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -59,11 +61,14 @@ func main() {
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -161,6 +166,60 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:

View File

@@ -6,9 +6,8 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/ollama/ollama/envconfig"
)
// ManifestLayer represents a layer in the manifest.
@@ -33,15 +32,31 @@ type ModelManifest struct {
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
switch runtime.GOOS {
case "darwin":
return filepath.Join(home, ".ollama", "models", "blobs")
case "linux":
return filepath.Join(home, ".ollama", "models", "blobs")
case "windows":
return filepath.Join(home, ".ollama", "models", "blobs")
default:
return filepath.Join(home, ".ollama", "models", "blobs")
}
}
// DefaultManifestDir returns the manifest storage directory.
// Respects OLLAMA_MODELS.
// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
}
// LoadManifest loads a manifest for the given model name.

View File

@@ -1,26 +0,0 @@
package imagegen
import (
"path/filepath"
"testing"
)
func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
modelsDir := filepath.Join(t.TempDir(), "models")
// Simulate packaged/systemd environment
t.Setenv("OLLAMA_MODELS", modelsDir)
t.Setenv("HOME", "/usr/share/ollama")
// Manifest dir must respect OLLAMA_MODELS
wantManifest := filepath.Join(modelsDir, "manifests")
if got := DefaultManifestDir(); got != wantManifest {
t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
}
// Blob dir must respect OLLAMA_MODELS
wantBlobs := filepath.Join(modelsDir, "blobs")
if got := DefaultBlobDir(); got != wantBlobs {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
}
}

View File

@@ -0,0 +1,87 @@
//go:build mlx
package qwen_image
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large


View File

@@ -0,0 +1,367 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
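A hedged usage sketch of the API added in this file: load the pipeline once, then generate from a GenerateConfig. The weights path and prompt are placeholders, and the mlx build tag plus model weights on disk are assumed:

//go:build mlx

package main

import (
	"log"

	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)

func main() {
	// Placeholder path; LoadPersistent loads the tokenizer, text encoder,
	// transformer, and VAE decoder as shown above.
	m, err := qwen_image.LoadPersistent("weights/Qwen-Image-2512")
	if err != nil {
		log.Fatal(err)
	}
	img, err := m.GenerateFromConfig(&qwen_image.GenerateConfig{
		Prompt: "a watercolor fox",
		Width:  1024,
		Height: 1024,
		Steps:  30,
		Seed:   42,
	})
	if err != nil {
		log.Fatal(err)
	}
	_ = img // *mlx.Array in [0, 1]; hand off to the caller's image writer
}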

View File

@@ -0,0 +1,218 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}
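To see the schedule this produces, a short sketch that builds the scheduler and prints the endpoints of the sigma schedule; for 30 steps at 1024x1024 (4096 image tokens, since (1024/8/2)^2 = 4096) the expected values match the golden numbers in the tests below: roughly 1.0 at the start, 0.02 before the terminal 0. The mlx build tag is assumed:

//go:build mlx

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)

func main() {
	s := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
	s.SetTimesteps(30, 4096)
	fmt.Println(len(s.Sigmas))                           // 31: steps + terminal sigma
	fmt.Println(s.Sigmas[0], s.Sigmas[29], s.Sigmas[30]) // ~1.000, ~0.020, 0
}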

View File

@@ -0,0 +1,135 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -0,0 +1,174 @@
//go:build mlx
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// For batch test, we'll use EncodeTextOnly with a single sequence
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}

Some files were not shown because too many files have changed in this diff