mirror of https://github.com/ollama/ollama.git
synced 2026-01-05 22:19:45 -05:00

Compare commits: parth/rend ... v0.13.1 (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5317202c38 | |
| | d771043e88 | |
| | f8f1071818 | |
| | d3e0a0dee4 | |
@@ -11,7 +11,6 @@ linters:
    - errorlint
    - exptostd
    - gocheckcompilerdirectives
    - gocritic
    - govet
    - ineffassign
    - intrange
cmd/chat_template/chat_template.py (deleted, 625 lines)
@@ -1,625 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "transformers>=4.57.0",
#     "jinja2",
#     "fastapi",
#     "uvicorn",
#     "pydantic",
#     "requests",
# ]
# ///
"""
Chat Template Testing Tool

Test HuggingFace chat templates against Ollama renderers.

Usage:
    # Run predefined test cases against a HuggingFace model
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3

    # Compare HuggingFace output with Ollama renderer
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3

    # Start server for manual curl testing
    uv run cmd/chat_template/chat_template.py --serve

    # Show chat template for a model
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""

import argparse
import json
import sys
from typing import Any

from transformers import AutoTokenizer

TEST_CASES = [
    {
        "name": "basic_user_message",
        "messages": [{"role": "user", "content": "Hello!"}],
        "tools": None,
    },
    {
        "name": "with_system_message",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "tools": None,
    },
    {
        "name": "multi_turn_conversation",
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ],
        "tools": None,
    },
    {
        "name": "with_tools",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather?"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "required": ["location"],
                        "properties": {
                            "location": {"type": "string", "description": "The city"}
                        },
                    },
                },
            }
        ],
    },
    {
        "name": "tool_call_and_response",
        "messages": [
            {"role": "user", "content": "What is the weather in SF?"},
            {
                "role": "assistant",
                "content": "Let me check the weather.",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "San Francisco"},
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "required": ["location"],
                        "properties": {
                            "location": {"type": "string", "description": "The city"}
                        },
                    },
                },
            }
        ],
    },
    {
        "name": "parallel_tool_calls",
        "messages": [
            {"role": "user", "content": "Get weather in SF and NYC"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "San Francisco"},
                        },
                    },
                    {
                        "id": "call_2",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "New York"},
                        },
                    },
                ],
            },
            {"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
            {"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ],
    },
    # Thinking tests
    {
        "name": "assistant_with_thinking",
        "messages": [
            {"role": "user", "content": "What is 2+2?"},
            {
                "role": "assistant",
                "content": "The answer is 4.",
                "thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
            },
            {"role": "user", "content": "And 3+3?"},
        ],
        "tools": None,
    },
    {
        "name": "thinking_with_tool_call",
        "messages": [
            {"role": "user", "content": "What's the weather in Paris?"},
            {
                "role": "assistant",
                "content": "I'll check the weather for you.",
                "thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "Paris"},
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ],
    },
    {
        "name": "thinking_only_no_content",
        "messages": [
            {"role": "user", "content": "Think about this silently."},
            {
                "role": "assistant",
                "content": "",  # HuggingFace requires content field
                "thinking": "I'm thinking about this but won't respond with visible content.",
            },
            {"role": "user", "content": "What did you think?"},
        ],
        "tools": None,
    },
]

# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}


def get_tokenizer(model_name: str):
    """Get or create tokenizer for the given model."""
    if model_name not in _tokenizer_cache:
        print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
        _tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
    return _tokenizer_cache[model_name]


def apply_template(
    model: str,
    messages: list[dict],
    tools: list[dict] | None = None,
) -> str:
    """Apply HuggingFace chat template to messages."""
    tokenizer = get_tokenizer(model)

    if tools:
        return tokenizer.apply_chat_template(
            messages,
            tools=tools,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )


def get_ollama_prompt(
    ollama_model: str,
    messages: list[dict],
    tools: list[dict] | None = None,
    ollama_host: str = "http://localhost:11434",
) -> str | None:
    """Get rendered prompt from Ollama using debug_render_only."""
    import requests

    # Convert messages to Ollama format
    ollama_messages = []
    for msg in messages:
        ollama_msg = {"role": msg["role"]}
        if "content" in msg:
            ollama_msg["content"] = msg["content"]
        if "thinking" in msg:
            ollama_msg["thinking"] = msg["thinking"]
        if "tool_calls" in msg:
            # Convert tool_calls to Ollama format
            tool_calls = []
            for tc in msg["tool_calls"]:
                tool_call = {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                if "id" in tc:
                    tool_call["id"] = tc["id"]
                tool_calls.append(tool_call)
            ollama_msg["tool_calls"] = tool_calls
        if "tool_call_id" in msg:
            ollama_msg["tool_call_id"] = msg["tool_call_id"]
        ollama_messages.append(ollama_msg)

    payload = {
        "model": ollama_model,
        "messages": ollama_messages,
        "stream": False,
        "_debug_render_only": True,
    }

    if tools:
        payload["tools"] = tools

    try:
        resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        # Field name is _debug_info with underscore prefix
        if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
            return data["_debug_info"]["rendered_template"]
        return None
    except requests.exceptions.ConnectionError:
        print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
        return None
    except Exception as e:
        print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
        return None


def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
    """Compute a unified diff between HuggingFace and Ollama prompts."""
    import difflib

    hf_lines = hf_prompt.splitlines(keepends=True)
    ollama_lines = ollama_prompt.splitlines(keepends=True)

    diff = difflib.unified_diff(
        ollama_lines,
        hf_lines,
        fromfile="Ollama",
        tofile="HuggingFace",
        lineterm="",
    )
    return "".join(diff)


def print_test_output(
    name: str,
    messages: list[dict],
    tools: list[dict] | None,
    hf_prompt: str,
    ollama_prompt: str | None = None,
    as_repr: bool = False,
):
    """Print test output in a format suitable for Go test creation and LLM diffing."""
    print(f"\n{'='*60}")
    print(f"Test: {name}")
    print("=" * 60)
    print("\n--- Input Messages ---")
    print(json.dumps(messages, indent=2))
    if tools:
        print("\n--- Tools ---")
        print(json.dumps(tools, indent=2))

    if ollama_prompt is not None:
        # Comparison mode
        if hf_prompt == ollama_prompt:
            print("\n--- Result: MATCH ---")
            print("\n--- Prompt (both identical) ---")
            if as_repr:
                print(repr(hf_prompt))
            else:
                print(hf_prompt)
        else:
            print("\n--- Result: MISMATCH ---")
            print("\n--- HuggingFace Prompt ---")
            if as_repr:
                print(repr(hf_prompt))
            else:
                print(hf_prompt)
            print("\n--- Ollama Prompt ---")
            if as_repr:
                print(repr(ollama_prompt))
            else:
                print(ollama_prompt)
            print("\n--- Diff (Ollama -> HuggingFace) ---")
            diff = compute_diff(hf_prompt, ollama_prompt)
            if diff:
                print(diff)
            else:
                print("(no line-level diff, check whitespace)")
    else:
        # HuggingFace only mode
        print("\n--- HuggingFace Prompt ---")
        if as_repr:
            print(repr(hf_prompt))
        else:
            print(hf_prompt)

    print("=" * 60)

def run_tests(
    model: str,
    as_repr: bool = False,
    test_filter: str | None = None,
    ollama_model: str | None = None,
    ollama_host: str = "http://localhost:11434",
):
    """Run all predefined test cases against a model."""
    if ollama_model:
        print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
    else:
        print(f"\nRunning tests against: {model}\n")

    matches = 0
    mismatches = 0
    errors = 0

    for test_case in TEST_CASES:
        name = test_case["name"]
        messages = test_case["messages"]
        tools = test_case["tools"]

        # Filter tests if specified
        if test_filter and test_filter.lower() not in name.lower():
            continue

        try:
            hf_prompt = apply_template(model, messages, tools)

            ollama_prompt = None
            if ollama_model:
                ollama_prompt = get_ollama_prompt(
                    ollama_model, messages, tools, ollama_host
                )
                if ollama_prompt is None:
                    errors += 1
                elif hf_prompt == ollama_prompt:
                    matches += 1
                else:
                    mismatches += 1

            print_test_output(
                name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
            )
        except Exception as e:
            errors += 1
            print(f"\n{'='*60}")
            print(f"Test: {name} - FAILED")
            print(f"--- Input Messages ---")
            print(json.dumps(messages, indent=2))
            if tools:
                print(f"--- Tools ---")
                print(json.dumps(tools, indent=2))
            print(f"--- Error ---")
            print(f"{e}")
            print("=" * 60)

    # Print summary if comparing
    if ollama_model:
        total = matches + mismatches + errors
        print(f"\n{'='*60}")
        print("SUMMARY")
        print("=" * 60)
        print(f" Total: {total}")
        print(f" Matches: {matches}")
        print(f" Mismatches: {mismatches}")
        print(f" Errors: {errors}")
        print("=" * 60)


def show_template(model: str):
    """Show the chat template for a model."""
    tokenizer = get_tokenizer(model)
    print(f"\nChat template for {model}:\n")
    print("-" * 60)
    print(tokenizer.chat_template)
    print("-" * 60)

def start_server(host: str = "0.0.0.0", port: int = 8000):
    """Start the FastAPI server for manual testing."""
    from typing import Optional, List, Dict, Any as TypingAny

    from fastapi import FastAPI, HTTPException
    from pydantic import BaseModel
    import uvicorn

    class Message(BaseModel):
        role: str
        content: Optional[str] = None
        tool_calls: Optional[List[Dict[str, TypingAny]]] = None
        tool_call_id: Optional[str] = None

    class GeneratePromptRequest(BaseModel):
        messages: List[Message]
        model: str = "PrimeIntellect/INTELLECT-3"
        tools: Optional[List[Dict[str, TypingAny]]] = None
        inject_tools_as_functions: bool = False

    class GeneratePromptResponse(BaseModel):
        prompt: str
        model: str

    app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")

    @app.post("/generate-prompt", response_model=GeneratePromptResponse)
    async def generate_prompt(request: GeneratePromptRequest):
        try:
            messages = []
            for msg in request.messages:
                message_dict = {"role": msg.role}
                if msg.content is not None:
                    message_dict["content"] = msg.content
                if msg.tool_calls is not None:
                    tool_calls = []
                    for tc in msg.tool_calls:
                        tc_copy = tc.copy()
                        if "function" in tc_copy and "arguments" in tc_copy["function"]:
                            args = tc_copy["function"]["arguments"]
                            if isinstance(args, str):
                                try:
                                    tc_copy["function"]["arguments"] = json.loads(args)
                                except json.JSONDecodeError:
                                    pass
                        tool_calls.append(tc_copy)
                    message_dict["tool_calls"] = tool_calls
                if msg.tool_call_id is not None:
                    message_dict["tool_call_id"] = msg.tool_call_id
                messages.append(message_dict)

            prompt = apply_template(request.model, messages, request.tools)
            return GeneratePromptResponse(prompt=prompt, model=request.model)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    print(f"Starting server on http://{host}:{port}")
    print("Endpoints:")
    print(" POST /generate-prompt - Generate prompt from messages")
    print(" GET /health - Health check")
    uvicorn.run(app, host=host, port=port)


def main():
    parser = argparse.ArgumentParser(
        description="HuggingFace Prompt Testing Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
    )
    parser.add_argument(
        "--ollama-model",
        "-o",
        type=str,
        help="Ollama model name to compare against (e.g., qwen3-coder)",
    )
    parser.add_argument(
        "--ollama-host",
        type=str,
        default="http://localhost:11434",
        help="Ollama server URL (default: http://localhost:11434)",
    )
    parser.add_argument(
        "--serve",
        "-s",
        action="store_true",
        help="Start FastAPI server for manual curl testing",
    )
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        default=8000,
        help="Server port (default: 8000)",
    )
    parser.add_argument(
        "--show-template",
        "-t",
        action="store_true",
        help="Show the chat template for the model",
    )
    parser.add_argument(
        "--repr",
        "-r",
        action="store_true",
        help="Output prompts as Python repr (shows escape sequences)",
    )
    parser.add_argument(
        "--filter",
        "-f",
        type=str,
        help="Filter tests by name (substring match)",
    )

    args = parser.parse_args()

    if args.serve:
        start_server(port=args.port)
    elif args.model:
        if args.show_template:
            show_template(args.model)
        else:
            run_tests(
                args.model,
                as_repr=args.repr,
                test_filter=args.filter,
                ollama_model=args.ollama_model,
                ollama_host=args.ollama_host,
            )
    else:
        parser.print_help()
        print("\nExample usage:")
        print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
        print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
        print(" uv run cmd/chat_template/chat_template.py --serve")
        sys.exit(1)


if __name__ == "__main__":
    main()
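For reference, the core trick the deleted script relied on can be reproduced without Python: it posted a chat request to a running Ollama server with the debug-only field _debug_render_only set, and read the rendered prompt back from _debug_info.rendered_template instead of generating. The Go sketch below is illustrative only and based on the fields used in the script above; those fields are internal debug hooks rather than a stable API, and the model name is just an example.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

// renderedTemplate asks Ollama to render the chat template for the given
// messages without generating a response, mirroring get_ollama_prompt above.
func renderedTemplate(host, model string, messages []map[string]any) (string, error) {
    payload := map[string]any{
        "model":              model,
        "messages":           messages,
        "stream":             false,
        "_debug_render_only": true, // debug-only flag used by the removed script
    }
    body, err := json.Marshal(payload)
    if err != nil {
        return "", err
    }
    resp, err := http.Post(host+"/api/chat", "application/json", bytes.NewReader(body))
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()

    var out struct {
        DebugInfo struct {
            RenderedTemplate string `json:"rendered_template"`
        } `json:"_debug_info"`
    }
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        return "", err
    }
    return out.DebugInfo.RenderedTemplate, nil
}

func main() {
    prompt, err := renderedTemplate("http://localhost:11434", "ministral-3",
        []map[string]any{{"role": "user", "content": "Hello!"}})
    if err != nil {
        panic(err)
    }
    fmt.Println(prompt)
}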
@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
        latest.Summary()
    }

    return &api.Message{Role: role, Content: fullResponse.String()}, nil
    return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}

func generate(cmd *cobra.Command, opts runOptions) error {
@@ -29,6 +29,15 @@ type mistral3Model struct {
        SlidingWindow *uint32 `json:"sliding_window"`
        HiddenAct     string  `json:"hidden_act"`
        VocabSize     uint32  `json:"vocab_size"`
        RopeParameters struct {
            BetaFast                  float32 `json:"beta_fast"`
            BetaSlow                  float32 `json:"beta_slow"`
            Factor                    float32 `json:"factor"`
            ScalingBeta               float32 `json:"llama_4_scaling_beta"`
            OrigMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
            RopeType                  string  `json:"rope_type"`
            RopeTheta                 float32 `json:"rope_theta"`
        } `json:"rope_parameters"`
    } `json:"text_config"`
    VisionModel struct {
        NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -61,8 +70,13 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
    kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
    kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
    kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
    kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
    kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
    kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
    kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)

    if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
        kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
        kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
    }

    // Vision configuration
    kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
@@ -33,6 +33,9 @@ func TestVisionModels(t *testing.T) {
            // Qwen 3 VL mixture of experts
            model: "qwen3-vl:30b",
        },
        {
            model: "ministral-3",
        },
    }

    for _, v := range testCases {
@@ -38,6 +38,7 @@ var (

    // Note: add newer models at the top of the list to test them first
    ollamaEngineChatModels = []string{
        "ministral-3",
        "qwen3-coder:30b",
        "gpt-oss:20b",
        "gemma3n:e2b",
@@ -167,6 +168,7 @@ var (
        "medllama2",
        "megadolphin",
        "minicpm-v",
        "ministral-3",
        "mistral-large",
        "mistral-nemo",
        "mistral-openorca",
@@ -270,6 +272,7 @@ var (
        "mistral",
        "qwen2.5",
        "qwen2",
        "ministral-3",
        "mistral-nemo",
        "mistral-small",
        "mixtral:8x22b",
@@ -874,7 +874,7 @@ func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.Devic
        }}
    }
    gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
    err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
    err := s.verifyLayout(systemInfo, systemGPUs, memory, requireFull, gpuLayers, layers)
    if err != nil {
        return nil, err
    }
@@ -943,7 +943,7 @@ func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMe
}

// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
    // These sizes will only increase as we go through additional iterations and get additional information.
    cpuSize := memory.InputWeights + memory.CPU.Graph
    var vramSize uint64
@@ -970,8 +970,8 @@ nextLayer:
    }

    if requireFull {
        if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
            slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
        if len(systemGPUs) > 0 && gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
            slog.Info("model requires more gpu memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
            return ErrLoadRequiredFull
        }
@@ -998,7 +998,7 @@ nextLayer:
        }
    }

    if gpuLayers.Sum() == 0 {
    if len(systemGPUs) > 0 && gpuLayers.Sum() == 0 {
        slog.Debug("insufficient VRAM to load any model layers")
    }
@@ -26,10 +26,11 @@ func TestLLMServerFitGPU(t *testing.T) {
        expectedErr error
    }{
        {
            name:     "No GPU",
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{},
            name:        "No GPU",
            layers:      []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:      -1,
            expected:    ml.GPULayersList{},
            requireFull: true, // Should not try to evict even though we can't load any layers
        },
        {
            name: "Full single GPU",
@@ -509,11 +509,9 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
    // At this time the only library we know needs a 2nd pass is ROCm since
    // rocblas will crash on unsupported devices. We want to find those crashes
    // during bootstrap discovery so we can eliminate those GPUs before the user
    // tries to run inference on them
    return d.Library == "ROCm"
    // ROCm: rocblas will crash on unsupported devices.
    // CUDA: verify CC is supported by the version of the library
    return d.Library == "ROCm" || d.Library == "CUDA"
}

// Set the init validation environment variable
@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
    positionsScale := m.getScale(ctx, batch.Positions)

    return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
    return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
}

func init() {
@@ -16,6 +16,8 @@ type TextOptions struct {
    hiddenSize, numHeads, numKVHeads int
    headDim, ropeDim                 int
    eps, ropeBase, ropeScale         float32
    ropeOrigPosEmbeddings            int
    ropeScalingBeta                  float32
}

type TextModel struct {
@@ -34,7 +36,7 @@ type SelfAttention struct {
    Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

@@ -49,6 +51,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

    if opts.ropeOrigPosEmbeddings > 0 {
        q = q.Mul(ctx, positionsScale)
    }

    kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
    kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
    return sa.Output.Forward(ctx, kqv)
@@ -76,11 +82,11 @@ type Layer struct {
    MLP *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
    hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)

    // In the final layer (outputs != nil), optimize by pruning to just the token positions
    // we need logits for.
@@ -97,7 +103,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
    return hiddenState.Add(ctx, residual)
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
    hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

    // image embeddings
@@ -114,25 +120,36 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
            lastLayerOutputs = outputs
        }

        hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
        hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
    }

    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    return m.Output.Forward(ctx, hiddenState)
}

func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
    posScale := make([]float32, len(positions))
    for n, pos := range positions {
        interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
        posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
    }
    return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}

func newTextModel(c fs.Config) *TextModel {
    return &TextModel{
        Layers: make([]Layer, c.Uint("block_count")),
        TextOptions: &TextOptions{
            hiddenSize: int(c.Uint("embedding_length")),
            numHeads:   int(c.Uint("attention.head_count")),
            numKVHeads: int(c.Uint("attention.head_count_kv")),
            headDim:    int(c.Uint("attention.key_length")),
            ropeDim:    int(c.Uint("rope.dimension_count")),
            eps:        c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:   c.Float("rope.freq_base"),
            ropeScale:  c.Float("rope.scaling.factor", 1),
            hiddenSize:            int(c.Uint("embedding_length")),
            numHeads:              int(c.Uint("attention.head_count")),
            numKVHeads:            int(c.Uint("attention.head_count_kv")),
            headDim:               int(c.Uint("attention.key_length")),
            ropeDim:               int(c.Uint("rope.dimension_count")),
            eps:                   c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:              c.Float("rope.freq_base"),
            ropeScale:             c.Float("rope.scaling.factor", 1),
            ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
            ropeScalingBeta:       c.Float("rope.scaling_beta"),
        },
    }
}
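Read together, getScale and the new guard in SelfAttention.Forward above scale each query by a factor that is 1 inside the original context window and grows logarithmically beyond it: with N = rope.scaling.original_context_length and beta = rope.scaling_beta, a token at position p has its query multiplied by

    s(p) = 1 + \beta \,\ln\!\left(1 + \left\lfloor p / N \right\rfloor\right)

so the first N positions are untouched and each further window of N positions adds another logarithmic increment. (The converter's llama_4_scaling_beta key above suggests this mirrors the Llama 4 attention scaling, but that naming is the only hint in this diff.)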
model/parsers/ministral.go (new file, 136 lines)
@@ -0,0 +1,136 @@
package parsers

import (
    "encoding/json"
    "fmt"
    "strings"

    "github.com/ollama/ollama/api"
)

type ministralParserState int

const (
    ministralCollectingContent = iota
    ministralCollectingThinkingContent
    ministralCollectingToolName
    ministralCollectingToolArgs
)

type MinistralParser struct {
    state              ministralParserState
    buffer             strings.Builder
    tools              []api.Tool
    hasThinkingSupport bool
    currentTool        *api.Tool
}

func (p *MinistralParser) HasToolSupport() bool {
    return true
}

func (p *MinistralParser) HasThinkingSupport() bool {
    return p.hasThinkingSupport
}

func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
    prefill := lastMessage != nil && lastMessage.Role == "assistant"
    if !p.HasThinkingSupport() {
        p.state = ministralCollectingContent
        return
    }

    if prefill && lastMessage.Content != "" {
        p.state = ministralCollectingContent
        return
    }

    p.state = ministralCollectingThinkingContent
}

func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
    p.tools = tools
    p.setInitialState(lastMessage)
    return tools
}

func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
    for i := range tools {
        if tools[i].Function.Name == n {
            return &tools[i], nil
        }
    }
    return nil, fmt.Errorf("tool '%s' not found", n)
}

func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
    p.buffer.WriteString(s)

    switch p.state {
    case ministralCollectingContent:
        if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
            before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
            if before != "" {
                return before, "", calls, nil
            }
            p.state = ministralCollectingToolName
        } else if strings.Contains(p.buffer.String(), "[THINK]") {
            p.state = ministralCollectingThinkingContent
            return "", "", calls, nil
        } else {
            p.buffer.Reset()
            return s, "", calls, nil
        }
    case ministralCollectingThinkingContent:
        if strings.Contains(p.buffer.String(), "[/THINK]") {
            thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
            p.state = ministralCollectingContent
            if after != "" {
                p.buffer.Reset()
                return after, thinkingContent, calls, nil
            }
            return "", thinkingContent, calls, nil
        } else {
            p.buffer.Reset()
            return "", s, calls, nil
        }
    case ministralCollectingToolName:
        if strings.Contains(p.buffer.String(), "[ARGS]") {
            name, _ := splitAtTag(&p.buffer, "[ARGS]", false)

            t, err := toolByName(p.tools, name)
            if err != nil {
                return "", "", calls, err
            }
            p.currentTool = t
            p.state = ministralCollectingToolArgs
            return "", "", calls, nil
        }
        return "", "", calls, nil
    case ministralCollectingToolArgs:
        if strings.Contains(p.buffer.String(), "}") {
            before, _ := splitAtTag(&p.buffer, "}", false)
            before += "}"

            var data map[string]any
            if err := json.Unmarshal([]byte(before), &data); err != nil {
                // todo - throw a better error
                return "", "", calls, err
            }

            p.state = ministralCollectingContent

            call := api.ToolCall{
                Function: api.ToolCallFunction{
                    Name:      p.currentTool.Function.Name,
                    Arguments: api.ToolCallFunctionArguments(data),
                },
            }
            calls = append(calls, call)
            return "", "", calls, nil
        }
        return "", "", calls, nil
    }

    return p.buffer.String(), thinking, calls, nil
}
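A hedged harness (not part of this change) for exercising the new parser: construct it, register a tool, and stream chunks through Add, printing whatever it emits at each step. The tag strings and the Init/Add signatures come from the file above; the chunk boundaries are an assumption, chosen so the [TOOL_CALLS] and [ARGS] markers arrive as their own chunks, as they typically would when streaming token by token.

package main

import (
    "fmt"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/model/parsers"
)

func main() {
    tools := []api.Tool{{Type: "function", Function: api.ToolFunction{Name: "get_weather"}}}

    // Zero value matches the "ministral" case in ParserForName (no thinking support).
    p := &parsers.MinistralParser{}
    p.Init(tools, nil, nil)

    // Simulated stream: plain content, then a tool call split across chunks.
    chunks := []string{
        "Let me check the weather.",
        "[TOOL_CALLS]",
        "get_weather[ARGS]",
        `{"location": "Paris"}`,
    }
    for _, chunk := range chunks {
        content, thinking, calls, err := p.Add(chunk, false)
        if err != nil {
            panic(err)
        }
        fmt.Printf("content=%q thinking=%q toolCalls=%d\n", content, thinking, len(calls))
    }
}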
@@ -1,6 +1,9 @@
package parsers

import (
    "strings"
    "unicode"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/harmony"
)
@@ -38,16 +41,17 @@ func ParserForName(name string) Parser {
    if parser, ok := registry.constructors[name]; ok {
        return parser()
    }
    var p Parser

    switch name {
    case "qwen3-coder":
        parser := &Qwen3CoderParser{}
        return parser
        p = &Qwen3CoderParser{}
    case "qwen3-vl-instruct":
        parser := &Qwen3VLParser{hasThinkingSupport: false}
        return parser
        p = &Qwen3VLParser{hasThinkingSupport: false}
    case "qwen3-vl-thinking":
        parser := &Qwen3VLParser{hasThinkingSupport: true}
        return parser
        p = &Qwen3VLParser{hasThinkingSupport: true}
    case "ministral":
        p = &MinistralParser{hasThinkingSupport: false}
    case "passthrough":
        return &PassthroughParser{}
    case "harmony":
@@ -57,6 +61,7 @@ func ParserForName(name string) Parser {
    default:
        return nil
    }
    return p
}

type PassthroughParser struct{}
@@ -76,3 +81,20 @@ func (p *PassthroughParser) HasToolSupport() bool {
func (p *PassthroughParser) HasThinkingSupport() bool {
    return false
}

func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
    split := strings.SplitN(sb.String(), tag, 2)
    if len(split) == 1 {
        sb.Reset()
        return split[0], ""
    }
    before := split[0]
    before = strings.TrimRightFunc(before, unicode.IsSpace)
    after := split[1]
    if trimAfter {
        after = strings.TrimLeftFunc(after, unicode.IsSpace)
    }
    sb.Reset()
    sb.WriteString(after)
    return before, after // return events
}
@@ -1,6 +1,7 @@
package parsers

import (
    "strings"
    "testing"

    "github.com/ollama/ollama/api"
@@ -95,3 +96,164 @@ func TestUnknownParserReturnsNil(t *testing.T) {
        t.Error("expected nil for unknown parser")
    }
}
func TestSplitAtTag(t *testing.T) {
    tests := []struct {
        name       string
        input      string
        tag        string
        trimAfter  bool
        wantBefore string
        wantAfter  string
        wantSB     string // expected content of strings.Builder after operation
    }{
        {
            name:       "basic split with trimAfter true",
            input:      "hello <!-- split --> world",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "hello",
            wantAfter:  "world",
            wantSB:     "world",
        },
        {
            name:       "basic split with trimAfter false",
            input:      "hello <!-- split --> world",
            tag:        "<!-- split -->",
            trimAfter:  false,
            wantBefore: "hello",
            wantAfter:  " world",
            wantSB:     " world",
        },
        {
            name:       "tag at beginning with trimAfter true",
            input:      "<!-- split -->world",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "",
            wantAfter:  "world",
            wantSB:     "world",
        },
        {
            name:       "tag at beginning with trimAfter false",
            input:      "<!-- split --> world",
            tag:        "<!-- split -->",
            trimAfter:  false,
            wantBefore: "",
            wantAfter:  " world",
            wantSB:     " world",
        },
        {
            name:       "tag at end with trimAfter true",
            input:      "hello <!-- split -->",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "hello",
            wantAfter:  "",
            wantSB:     "",
        },
        {
            name:       "tag at end with trimAfter false",
            input:      "hello <!-- split -->",
            tag:        "<!-- split -->",
            trimAfter:  false,
            wantBefore: "hello",
            wantAfter:  "",
            wantSB:     "",
        },
        {
            name:       "multiple tags splits at first occurrence",
            input:      "hello <!-- split --> world <!-- split --> end",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "hello",
            wantAfter:  "world <!-- split --> end",
            wantSB:     "world <!-- split --> end",
        },
        {
            name:       "tag not present",
            input:      "hello world",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "hello world",
            wantAfter:  "",
            wantSB:     "",
        },
        {
            name:       "empty input",
            input:      "",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "",
            wantAfter:  "",
            wantSB:     "",
        },
        {
            name:       "only whitespace before tag",
            input:      " \t\n<!-- split -->world",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "",
            wantAfter:  "world",
            wantSB:     "world",
        },
        {
            name:       "only whitespace after tag with trimAfter true",
            input:      "hello<!-- split --> \t\n",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: "hello",
            wantAfter:  "",
            wantSB:     "",
        },
        {
            name:       "only whitespace after tag with trimAfter false",
            input:      "hello<!-- split --> \t\n",
            tag:        "<!-- split -->",
            trimAfter:  false,
            wantBefore: "hello",
            wantAfter:  " \t\n",
            wantSB:     " \t\n",
        },
        {
            name:       "complex whitespace trimming",
            input:      " hello \t\n <!-- split --> \n\t world ",
            tag:        "<!-- split -->",
            trimAfter:  true,
            wantBefore: " hello",
            wantAfter:  "world ",
            wantSB:     "world ",
        },
        {
            name:       "tag with special characters",
            input:      "text <tag attr=\"value\"> more text",
            tag:        "<tag attr=\"value\">",
            trimAfter:  true,
            wantBefore: "text",
            wantAfter:  "more text",
            wantSB:     "more text",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            sb := &strings.Builder{}
            sb.WriteString(tt.input)

            before, after := splitAtTag(sb, tt.tag, tt.trimAfter)

            // Check return values
            if before != tt.wantBefore {
                t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
            }
            if after != tt.wantAfter {
                t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
            }

            // Check strings.Builder state
            if sb.String() != tt.wantSB {
                t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
            }
        })
    }
}
@@ -70,7 +70,6 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
    p.buffer.WriteString(s)
    events := p.parseEvents()

    var toolCalls []api.ToolCall
    var contentSb strings.Builder
    var thinkingSb strings.Builder
    for _, event := range events {
@@ -81,7 +80,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
                slog.Warn("qwen tool call parsing failed", "error", err)
                return "", "", nil, err
            }
            toolCalls = append(toolCalls, toolCall)
            calls = append(calls, toolCall)
        case qwenEventThinkingContent:
            thinkingSb.WriteString(event.content)
        case qwenEventContent:
@@ -91,7 +90,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
        }
    }

    return contentSb.String(), thinkingSb.String(), toolCalls, nil
    return contentSb.String(), thinkingSb.String(), calls, nil
}

func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -113,19 +112,6 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
    return all
}

func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
    split := strings.SplitN(p.buffer.String(), tag, 2)
    before := split[0]
    before = strings.TrimRightFunc(before, unicode.IsSpace)
    after := split[1]
    if trimAfter {
        after = strings.TrimLeftFunc(after, unicode.IsSpace)
    }
    p.buffer.Reset()
    p.buffer.WriteString(after)
    return before, after // return events
}

func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
    trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
    p.buffer.Reset()
@@ -144,7 +130,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
    case CollectingContent:
        if strings.Contains(p.buffer.String(), toolOpenTag) {
            // events = emitContentBeforeTag(p, events, toolOpenTag)
            before, _ := splitAtTag(p, toolOpenTag, false)
            before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
            if len(before) > 0 {
                events = append(events, qwenEventContent{content: before})
            }
@@ -195,7 +181,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
        }
    case CollectingThinkingContent:
        if strings.Contains(p.buffer.String(), thinkingCloseTag) {
            thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
            thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
            if len(thinking) > 0 {
                events = append(events, qwenEventThinkingContent{content: thinking})
            }