Compare commits


4 Commits

Author SHA1 Message Date
Jesse Gross
5317202c38 llm: Don't always evict models on CPU-only systems
Model eviction happens when we have at least one other model
loaded and are unable to load all layers into VRAM. However, on
CPU-only systems we can never load layers into VRAM, so this
constantly triggered eviction.

Fixes #13227
2025-12-02 10:58:08 -08:00
Daniel Hiltgen
d771043e88 test: add ministral-3 (#13300) 2025-12-02 09:52:16 -08:00
Daniel Hiltgen
f8f1071818 CUDA: verify CC is supported by target library (#13298) 2025-12-02 09:28:41 -08:00
Patrick Devine
d3e0a0dee4 model: ministral w/ llama4 scaling (#13292)
This change:

* fixes rope scaling in the mistral converter
* updates ministral to include llama4 scaling
* includes a new ministral parser for parsing reasoning and tool calling

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2025-12-01 23:20:14 -08:00
15 changed files with 398 additions and 681 deletions

View File

@@ -11,7 +11,6 @@ linters:
- errorlint
- exptostd
- gocheckcompilerdirectives
- gocritic
- govet
- ineffassign
- intrange

View File

@@ -1,625 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=4.57.0",
# "jinja2",
# "fastapi",
# "uvicorn",
# "pydantic",
# "requests",
# ]
# ///
"""
Chat Template Testing Tool
Test HuggingFace chat templates against Ollama renderers.
Usage:
# Run predefined test cases against a HuggingFace model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3
# Compare HuggingFace output with Ollama renderer
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3
# Start server for manual curl testing
uv run cmd/chat_template/chat_template.py --serve
# Show chat template for a model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""
import argparse
import json
import sys
from typing import Any
from transformers import AutoTokenizer
TEST_CASES = [
{
"name": "basic_user_message",
"messages": [{"role": "user", "content": "Hello!"}],
"tools": None,
},
{
"name": "with_system_message",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
"tools": None,
},
{
"name": "multi_turn_conversation",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
"tools": None,
},
{
"name": "with_tools",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather?"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "tool_call_and_response",
"messages": [
{"role": "user", "content": "What is the weather in SF?"},
{
"role": "assistant",
"content": "Let me check the weather.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "parallel_tool_calls",
"messages": [
{"role": "user", "content": "Get weather in SF and NYC"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "New York"},
},
},
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
{"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
# Thinking tests
{
"name": "assistant_with_thinking",
"messages": [
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": "The answer is 4.",
"thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
},
{"role": "user", "content": "And 3+3?"},
],
"tools": None,
},
{
"name": "thinking_with_tool_call",
"messages": [
{"role": "user", "content": "What's the weather in Paris?"},
{
"role": "assistant",
"content": "I'll check the weather for you.",
"thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Paris"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
{
"name": "thinking_only_no_content",
"messages": [
{"role": "user", "content": "Think about this silently."},
{
"role": "assistant",
"content": "", # HuggingFace requires content field
"thinking": "I'm thinking about this but won't respond with visible content.",
},
{"role": "user", "content": "What did you think?"},
],
"tools": None,
},
]
# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}
def get_tokenizer(model_name: str):
"""Get or create tokenizer for the given model."""
if model_name not in _tokenizer_cache:
print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
return _tokenizer_cache[model_name]
def apply_template(
model: str,
messages: list[dict],
tools: list[dict] | None = None,
) -> str:
"""Apply HuggingFace chat template to messages."""
tokenizer = get_tokenizer(model)
if tools:
return tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=False,
add_generation_prompt=True,
)
else:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
def get_ollama_prompt(
ollama_model: str,
messages: list[dict],
tools: list[dict] | None = None,
ollama_host: str = "http://localhost:11434",
) -> str | None:
"""Get rendered prompt from Ollama using debug_render_only."""
import requests
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_msg = {"role": msg["role"]}
if "content" in msg:
ollama_msg["content"] = msg["content"]
if "thinking" in msg:
ollama_msg["thinking"] = msg["thinking"]
if "tool_calls" in msg:
# Convert tool_calls to Ollama format
tool_calls = []
for tc in msg["tool_calls"]:
tool_call = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
if "id" in tc:
tool_call["id"] = tc["id"]
tool_calls.append(tool_call)
ollama_msg["tool_calls"] = tool_calls
if "tool_call_id" in msg:
ollama_msg["tool_call_id"] = msg["tool_call_id"]
ollama_messages.append(ollama_msg)
payload = {
"model": ollama_model,
"messages": ollama_messages,
"stream": False,
"_debug_render_only": True,
}
if tools:
payload["tools"] = tools
try:
resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
# Field name is _debug_info with underscore prefix
if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
return data["_debug_info"]["rendered_template"]
return None
except requests.exceptions.ConnectionError:
print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
return None
except Exception as e:
print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
return None
def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
"""Compute a unified diff between HuggingFace and Ollama prompts."""
import difflib
hf_lines = hf_prompt.splitlines(keepends=True)
ollama_lines = ollama_prompt.splitlines(keepends=True)
diff = difflib.unified_diff(
ollama_lines,
hf_lines,
fromfile="Ollama",
tofile="HuggingFace",
lineterm="",
)
return "".join(diff)
def print_test_output(
name: str,
messages: list[dict],
tools: list[dict] | None,
hf_prompt: str,
ollama_prompt: str | None = None,
as_repr: bool = False,
):
"""Print test output in a format suitable for Go test creation and LLM diffing."""
print(f"\n{'='*60}")
print(f"Test: {name}")
print("=" * 60)
print("\n--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print("\n--- Tools ---")
print(json.dumps(tools, indent=2))
if ollama_prompt is not None:
# Comparison mode
if hf_prompt == ollama_prompt:
print("\n--- Result: MATCH ---")
print("\n--- Prompt (both identical) ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
else:
print("\n--- Result: MISMATCH ---")
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("\n--- Ollama Prompt ---")
if as_repr:
print(repr(ollama_prompt))
else:
print(ollama_prompt)
print("\n--- Diff (Ollama -> HuggingFace) ---")
diff = compute_diff(hf_prompt, ollama_prompt)
if diff:
print(diff)
else:
print("(no line-level diff, check whitespace)")
else:
# HuggingFace only mode
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("=" * 60)
def run_tests(
model: str,
as_repr: bool = False,
test_filter: str | None = None,
ollama_model: str | None = None,
ollama_host: str = "http://localhost:11434",
):
"""Run all predefined test cases against a model."""
if ollama_model:
print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
else:
print(f"\nRunning tests against: {model}\n")
matches = 0
mismatches = 0
errors = 0
for test_case in TEST_CASES:
name = test_case["name"]
messages = test_case["messages"]
tools = test_case["tools"]
# Filter tests if specified
if test_filter and test_filter.lower() not in name.lower():
continue
try:
hf_prompt = apply_template(model, messages, tools)
ollama_prompt = None
if ollama_model:
ollama_prompt = get_ollama_prompt(
ollama_model, messages, tools, ollama_host
)
if ollama_prompt is None:
errors += 1
elif hf_prompt == ollama_prompt:
matches += 1
else:
mismatches += 1
print_test_output(
name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
)
except Exception as e:
errors += 1
print(f"\n{'='*60}")
print(f"Test: {name} - FAILED")
print(f"--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print(f"--- Tools ---")
print(json.dumps(tools, indent=2))
print(f"--- Error ---")
print(f"{e}")
print("=" * 60)
# Print summary if comparing
if ollama_model:
total = matches + mismatches + errors
print(f"\n{'='*60}")
print("SUMMARY")
print("=" * 60)
print(f" Total: {total}")
print(f" Matches: {matches}")
print(f" Mismatches: {mismatches}")
print(f" Errors: {errors}")
print("=" * 60)
def show_template(model: str):
"""Show the chat template for a model."""
tokenizer = get_tokenizer(model)
print(f"\nChat template for {model}:\n")
print("-" * 60)
print(tokenizer.chat_template)
print("-" * 60)
def start_server(host: str = "0.0.0.0", port: int = 8000):
"""Start the FastAPI server for manual testing."""
from typing import Optional, List, Dict, Any as TypingAny
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
class Message(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, TypingAny]]] = None
tool_call_id: Optional[str] = None
class GeneratePromptRequest(BaseModel):
messages: List[Message]
model: str = "PrimeIntellect/INTELLECT-3"
tools: Optional[List[Dict[str, TypingAny]]] = None
inject_tools_as_functions: bool = False
class GeneratePromptResponse(BaseModel):
prompt: str
model: str
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
async def generate_prompt(request: GeneratePromptRequest):
try:
messages = []
for msg in request.messages:
message_dict = {"role": msg.role}
if msg.content is not None:
message_dict["content"] = msg.content
if msg.tool_calls is not None:
tool_calls = []
for tc in msg.tool_calls:
tc_copy = tc.copy()
if "function" in tc_copy and "arguments" in tc_copy["function"]:
args = tc_copy["function"]["arguments"]
if isinstance(args, str):
try:
tc_copy["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError:
pass
tool_calls.append(tc_copy)
message_dict["tool_calls"] = tool_calls
if msg.tool_call_id is not None:
message_dict["tool_call_id"] = msg.tool_call_id
messages.append(message_dict)
prompt = apply_template(request.model, messages, request.tools)
return GeneratePromptResponse(prompt=prompt, model=request.model)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
print(f"Starting server on http://{host}:{port}")
print("Endpoints:")
print(" POST /generate-prompt - Generate prompt from messages")
print(" GET /health - Health check")
uvicorn.run(app, host=host, port=port)
def main():
parser = argparse.ArgumentParser(
description="HuggingFace Prompt Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--model",
"-m",
type=str,
help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
)
parser.add_argument(
"--ollama-model",
"-o",
type=str,
help="Ollama model name to compare against (e.g., qwen3-coder)",
)
parser.add_argument(
"--ollama-host",
type=str,
default="http://localhost:11434",
help="Ollama server URL (default: http://localhost:11434)",
)
parser.add_argument(
"--serve",
"-s",
action="store_true",
help="Start FastAPI server for manual curl testing",
)
parser.add_argument(
"--port",
"-p",
type=int,
default=8000,
help="Server port (default: 8000)",
)
parser.add_argument(
"--show-template",
"-t",
action="store_true",
help="Show the chat template for the model",
)
parser.add_argument(
"--repr",
"-r",
action="store_true",
help="Output prompts as Python repr (shows escape sequences)",
)
parser.add_argument(
"--filter",
"-f",
type=str,
help="Filter tests by name (substring match)",
)
args = parser.parse_args()
if args.serve:
start_server(port=args.port)
elif args.model:
if args.show_template:
show_template(args.model)
else:
run_tests(
args.model,
as_repr=args.repr,
test_filter=args.filter,
ollama_model=args.ollama_model,
ollama_host=args.ollama_host,
)
else:
parser.print_help()
print("\nExample usage:")
print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
print(" uv run cmd/chat_template/chat_template.py --serve")
sys.exit(1)
if __name__ == "__main__":
main()
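For reference, the _debug_render_only flag this script depended on can be exercised directly. A minimal Go sketch of a hypothetical client, assuming a local Ollama server on the default port; the underscore-prefixed request and response fields are exactly those used by get_ollama_prompt above:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Ask /api/chat to render the prompt template without running inference.
	payload := map[string]any{
		"model":              "ministral-3", // illustrative model name
		"messages":           []map[string]string{{"role": "user", "content": "Hello!"}},
		"stream":             false,
		"_debug_render_only": true,
	}
	body, _ := json.Marshal(payload)
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	// The rendered prompt comes back under _debug_info.rendered_template.
	var data struct {
		DebugInfo struct {
			RenderedTemplate string `json:"rendered_template"`
		} `json:"_debug_info"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		panic(err)
	}
	fmt.Println(data.DebugInfo.RenderedTemplate)
}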

View File

@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary()
}
return &api.Message{Role: role, Content: fullResponse.String()}, nil
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {

View File

@@ -29,6 +29,15 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -61,8 +70,13 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers

View File

@@ -33,6 +33,9 @@ func TestVisionModels(t *testing.T) {
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
{
model: "ministral-3",
},
}
for _, v := range testCases {

View File

@@ -38,6 +38,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
@@ -167,6 +168,7 @@ var (
"medllama2",
"megadolphin",
"minicpm-v",
"ministral-3",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
@@ -270,6 +272,7 @@ var (
"mistral",
"qwen2.5",
"qwen2",
"ministral-3",
"mistral-nemo",
"mistral-small",
"mixtral:8x22b",

View File

@@ -874,7 +874,7 @@ func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.Devic
}}
}
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
err := s.verifyLayout(systemInfo, systemGPUs, memory, requireFull, gpuLayers, layers)
if err != nil {
return nil, err
}
@@ -943,7 +943,7 @@ func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMe
}
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
// These sizes will only increase as we go through additional iterations and get additional information.
cpuSize := memory.InputWeights + memory.CPU.Graph
var vramSize uint64
@@ -970,8 +970,8 @@ nextLayer:
}
if requireFull {
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
if len(systemGPUs) > 0 && gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more gpu memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
return ErrLoadRequiredFull
}
@@ -998,7 +998,7 @@ nextLayer:
}
}
if gpuLayers.Sum() == 0 {
if len(systemGPUs) > 0 && gpuLayers.Sum() == 0 {
slog.Debug("insufficient VRAM to load any model layers")
}
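A condensed, runnable sketch of the guard this change adds (simplified: the real condition in verifyLayout also honors s.options.NumGPU; the names here are illustrative):

package main

import "fmt"

// shouldEvict mirrors the fixed check: eviction is only worth attempting when
// at least one GPU exists, since a CPU-only system can never offload layers
// to VRAM and would otherwise trigger eviction on every load (issue #13227).
func shouldEvict(gpuCount, offloadedLayers, totalLayers int) bool {
	return gpuCount > 0 && offloadedLayers < totalLayers
}

func main() {
	fmt.Println(shouldEvict(0, 0, 32))  // false: CPU-only, never evict
	fmt.Println(shouldEvict(1, 10, 32)) // true: GPU present, partial offload
}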

View File

@@ -26,10 +26,11 @@ func TestLLMServerFitGPU(t *testing.T) {
expectedErr error
}{
{
name: "No GPU",
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{},
name: "No GPU",
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{},
requireFull: true, // Should not try to evict even though we can't load any layers
},
{
name: "Full single GPU",

View File

@@ -509,11 +509,9 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// At this time the only library we know needs a 2nd pass is ROCm since
// rocblas will crash on unsupported devices. We want to find those crashes
// during bootstrap discovery so we can eliminate those GPUs before the user
// tries to run inference on them
return d.Library == "ROCm"
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable

View File

@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
positionsScale := m.getScale(ctx, batch.Positions)
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -16,6 +16,8 @@ type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
}
type TextModel struct {
@@ -34,7 +36,7 @@ type SelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
@@ -49,6 +51,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if opts.ropeOrigPosEmbeddings > 0 {
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
@@ -76,11 +82,11 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
@@ -97,7 +103,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// image embeddings
@@ -114,25 +120,36 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
posScale := make([]float32, len(positions))
for n, pos := range positions {
interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
}
return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
},
}
}
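For intuition, getScale implements scale(pos) = 1 + beta * ln(1 + floor(pos / N)), where N is rope.scaling.original_context_length and beta is rope.scaling_beta. A standalone sketch with illustrative values (the real N and beta come from the converted GGUF metadata):

package main

import (
	"fmt"
	"math"
)

// positionScale reproduces the per-position formula from getScale.
func positionScale(pos int32, origCtx int, beta float64) float64 {
	interval := math.Floor(float64(pos) / float64(origCtx))
	return 1.0 + beta*math.Log(1.0+interval)
}

func main() {
	const origCtx, beta = 4096, 0.1 // illustrative, not the shipped values
	for _, pos := range []int32{0, 4095, 4096, 16384} {
		// Positions inside the original window keep scale 1.0; beyond it,
		// the query vectors are scaled up logarithmically per window crossed.
		fmt.Printf("pos=%5d scale=%.4f\n", pos, positionScale(pos, origCtx, beta))
	}
}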

model/parsers/ministral.go (new file, 136 lines)
View File

@@ -0,0 +1,136 @@
package parsers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type ministralParserState int
const (
ministralCollectingContent = iota
ministralCollectingThinkingContent
ministralCollectingToolName
ministralCollectingToolArgs
)
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
currentTool *api.Tool
}
func (p *MinistralParser) HasToolSupport() bool {
return true
}
func (p *MinistralParser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !p.HasThinkingSupport() {
p.state = ministralCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = ministralCollectingContent
return
}
p.state = ministralCollectingThinkingContent
}
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools
}
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
for i := range tools {
if tools[i].Function.Name == n {
return &tools[i], nil
}
}
return nil, fmt.Errorf("tool '%s' not found", n)
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
}
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
}
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var data map[string]any
if err := json.Unmarshal([]byte(before), &data); err != nil {
// TODO: return a more descriptive error
return "", "", calls, err
}
p.state = ministralCollectingContent
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Arguments: api.ToolCallFunctionArguments(data),
},
}
calls = append(calls, call)
return "", "", calls, nil
}
return "", "", calls, nil
}
return p.buffer.String(), thinking, calls, nil
}
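A minimal streaming sketch for the new parser, using only the methods shown above (zero-value construction, so thinking support is off and no tools are registered; plain text should pass straight through as content):

package main

import (
	"fmt"

	"github.com/ollama/ollama/model/parsers"
)

func main() {
	p := &parsers.MinistralParser{}
	p.Init(nil, nil, nil) // no tools, no assistant prefill

	chunks := []string{"Hello", ", world", "!"}
	for i, chunk := range chunks {
		content, thinking, calls, err := p.Add(chunk, i == len(chunks)-1)
		if err != nil {
			fmt.Println("parse error:", err)
			return
		}
		fmt.Printf("content=%q thinking=%q calls=%d\n", content, thinking, len(calls))
	}
}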

View File

@@ -1,6 +1,9 @@
package parsers
import (
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
@@ -38,16 +41,17 @@ func ParserForName(name string) Parser {
if parser, ok := registry.constructors[name]; ok {
return parser()
}
var p Parser
switch name {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":
parser := &Qwen3VLParser{hasThinkingSupport: false}
return parser
p = &Qwen3VLParser{hasThinkingSupport: false}
case "qwen3-vl-thinking":
parser := &Qwen3VLParser{hasThinkingSupport: true}
return parser
p = &Qwen3VLParser{hasThinkingSupport: true}
case "ministral":
p = &MinistralParser{hasThinkingSupport: false}
case "passthrough":
return &PassthroughParser{}
case "harmony":
@@ -57,6 +61,7 @@ func ParserForName(name string) Parser {
default:
return nil
}
return p
}
type PassthroughParser struct{}
@@ -76,3 +81,20 @@ func (p *PassthroughParser) HasToolSupport() bool {
func (p *PassthroughParser) HasThinkingSupport() bool {
return false
}
func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(sb.String(), tag, 2)
if len(split) == 1 {
sb.Reset()
return split[0], ""
}
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
sb.Reset()
sb.WriteString(after)
return before, after
}

View File

@@ -1,6 +1,7 @@
package parsers
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
@@ -95,3 +96,164 @@ func TestUnknownParserReturnsNil(t *testing.T) {
t.Error("expected nil for unknown parser")
}
}
func TestSplitAtTag(t *testing.T) {
tests := []struct {
name string
input string
tag string
trimAfter bool
wantBefore string
wantAfter string
wantSB string // expected content of strings.Builder after operation
}{
{
name: "basic split with trimAfter true",
input: "hello <!-- split --> world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "world",
wantSB: "world",
},
{
name: "basic split with trimAfter false",
input: "hello <!-- split --> world",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: " world",
wantSB: " world",
},
{
name: "tag at beginning with trimAfter true",
input: "<!-- split -->world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "world",
wantSB: "world",
},
{
name: "tag at beginning with trimAfter false",
input: "<!-- split --> world",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "",
wantAfter: " world",
wantSB: " world",
},
{
name: "tag at end with trimAfter true",
input: "hello <!-- split -->",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "tag at end with trimAfter false",
input: "hello <!-- split -->",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "multiple tags splits at first occurrence",
input: "hello <!-- split --> world <!-- split --> end",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "world <!-- split --> end",
wantSB: "world <!-- split --> end",
},
{
name: "tag not present",
input: "hello world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello world",
wantAfter: "",
wantSB: "",
},
{
name: "empty input",
input: "",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "",
wantSB: "",
},
{
name: "only whitespace before tag",
input: " \t\n<!-- split -->world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "world",
wantSB: "world",
},
{
name: "only whitespace after tag with trimAfter true",
input: "hello<!-- split --> \t\n",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "only whitespace after tag with trimAfter false",
input: "hello<!-- split --> \t\n",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: " \t\n",
wantSB: " \t\n",
},
{
name: "complex whitespace trimming",
input: " hello \t\n <!-- split --> \n\t world ",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: " hello",
wantAfter: "world ",
wantSB: "world ",
},
{
name: "tag with special characters",
input: "text <tag attr=\"value\"> more text",
tag: "<tag attr=\"value\">",
trimAfter: true,
wantBefore: "text",
wantAfter: "more text",
wantSB: "more text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sb := &strings.Builder{}
sb.WriteString(tt.input)
before, after := splitAtTag(sb, tt.tag, tt.trimAfter)
// Check return values
if before != tt.wantBefore {
t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
}
if after != tt.wantAfter {
t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
}
// Check strings.Builder state
if sb.String() != tt.wantSB {
t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
}
})
}
}

View File

@@ -70,7 +70,6 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
@@ -81,7 +80,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
calls = append(calls, toolCall)
case qwenEventThinkingContent:
thinkingSb.WriteString(event.content)
case qwenEventContent:
@@ -91,7 +90,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -113,19 +112,6 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
return all
}
func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after // return events
}
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
@@ -144,7 +130,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
case CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
// events = emitContentBeforeTag(p, events, toolOpenTag)
before, _ := splitAtTag(p, toolOpenTag, false)
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
@@ -195,7 +181,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
case CollectingThinkingContent:
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
}