Mirror of https://github.com/ollama/ollama.git (synced 2026-02-17 11:03:30 -05:00)

Compare commits: pdevine/qw...main (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 3a88f7eb20 | |
| | 0d5da826d4 | |
```diff
@@ -182,10 +182,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			mfConfig.System = cmd.Args
 		case "license":
 			mfConfig.License = cmd.Args
-		case "parser":
-			mfConfig.Parser = cmd.Args
-		case "renderer":
-			mfConfig.Renderer = cmd.Args
 		}
 	}
```
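For reference, the removed cases were what wired parser and renderer directives from a Modelfile into `ModelfileConfig`. A sketch of the input they consumed; the directive spellings `PARSER` and `RENDERER` are inferred from the case labels and are not shown in this diff:

```go
// Hypothetical Modelfile lines the removed cases handled:
//
//	PARSER qwen3-thinking
//	RENDERER qwen3
//
// which would have populated:
mfConfig := &ModelfileConfig{Parser: "qwen3-thinking", Renderer: "qwen3"}
_ = mfConfig
```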
```diff
@@ -45,10 +45,6 @@ func ParserForName(name string) Parser {
 	var p Parser
 
 	switch name {
-	case "qwen3":
-		p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
-	case "qwen3-thinking":
-		p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
 	case "qwen3-coder":
 		p = &Qwen3CoderParser{}
 	case "qwen3-vl-instruct":
```
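After this change only the coder and VL cases remain, so the plain and thinking variants no longer resolve to a dedicated parser here. Illustrative calls, assuming the surrounding function keeps its existing behavior for unmatched names:

```go
p := ParserForName("qwen3-coder") // still returns a &Qwen3CoderParser{}
p = ParserForName("qwen3")        // no longer matches a case in this switch
_ = p
```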
```diff
@@ -54,8 +54,6 @@ func TestBuiltInParsersStillWork(t *testing.T) {
 		name string
 	}{
 		{"passthrough"},
-		{"qwen3"},
-		{"qwen3-thinking"},
 		{"qwen3-coder"},
 		{"harmony"},
 	}
```
@@ -1,335 +0,0 @@ (entire file removed in this comparison; the Qwen3Parser implementation in package parsers, shown below for reference)

```go
package parsers

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type qwen3ParserState int

const (
	qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
	qwen3ParserStateThinkingStartedEatingWhitespace
	qwen3ParserStateCollectingThinking
	qwen3ParserStateThinkingDoneEatingWhitespace
	qwen3ParserStateCollectingContent
	qwen3ParserStateToolStartedEatingWhitespace
	qwen3ParserStateCollectingToolContent
)

const (
	qwen3ThinkingOpenTag  = "<think>"
	qwen3ThinkingCloseTag = "</think>"
	qwen3ToolOpenTag      = "<tool_call>"
	qwen3ToolCloseTag     = "</tool_call>"
)

// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
// with thinking content directly (without an opening tag).
type Qwen3Parser struct {
	state                  qwen3ParserState
	buffer                 strings.Builder
	tools                  []api.Tool
	hasThinkingSupport     bool
	defaultThinking        bool
	maybeThinkingOpenAtBOL bool
}

func (p *Qwen3Parser) HasToolSupport() bool {
	return true
}

func (p *Qwen3Parser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}

func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.buffer.Reset()

	thinkingEnabled := thinkValue != nil && thinkValue.Bool()
	if thinkValue == nil {
		thinkingEnabled = p.defaultThinking
	}

	if p.hasThinkingSupport && thinkingEnabled {
		p.state = qwen3ParserStateCollectingThinking
		p.maybeThinkingOpenAtBOL = true
	} else {
		p.state = qwen3ParserStateCollectingContent
		p.maybeThinkingOpenAtBOL = false
	}
	return tools
}

type qwen3Event interface {
	isQwen3Event()
}

type qwen3EventContent struct {
	content string
}

func (qwen3EventContent) isQwen3Event() {}

type qwen3EventRawToolCall struct {
	raw string
}

func (qwen3EventRawToolCall) isQwen3Event() {}

type qwen3EventThinkingContent struct {
	content string
}

func (qwen3EventThinkingContent) isQwen3Event() {}

func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case qwen3EventRawToolCall:
			toolCall, err := parseQwen3ToolCall(event, p.tools)
			if err != nil {
				slog.Warn("qwen3 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			calls = append(calls, toolCall)
		case qwen3EventThinkingContent:
			thinkingSb.WriteString(event.content)
		case qwen3EventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), calls, nil
}

func (p *Qwen3Parser) parseEvents() []qwen3Event {
	var all []qwen3Event

	keepLooping := true
	for keepLooping {
		var events []qwen3Event
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	if len(all) > 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
	}

	return all
}

func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
	p.buffer.Reset()
	if trimmed == "" {
		return nil, false
	}
	p.state = nextState
	p.buffer.WriteString(trimmed)
	return nil, true
}

func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
	return splitAtTag(&p.buffer, tag, trimAfter)
}

func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
	var events []qwen3Event

	switch p.state {
	case qwen3ParserStateLookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
			after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = qwen3ParserStateThinkingStartedEatingWhitespace
			} else {
				p.state = qwen3ParserStateCollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
			return events, false
		} else if trimmed == "" {
			return events, false
		}
		p.state = qwen3ParserStateCollectingContent
		return events, true

	case qwen3ParserStateThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)

	case qwen3ParserStateCollectingThinking:
		acc := p.buffer.String()

		// Some qwen3 checkpoints emit an explicit opening <think> tag even
		// though the prompt already ended with <think>. Strip exactly one
		// leading opening tag if present.
		if p.maybeThinkingOpenAtBOL {
			trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
			if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
				after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
				after = strings.TrimLeftFunc(after, unicode.IsSpace)
				p.buffer.Reset()
				p.buffer.WriteString(after)
				if after == "" {
					return events, false
				}
				p.maybeThinkingOpenAtBOL = false
				return events, true
			}
			if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
				return events, false
			}
			p.maybeThinkingOpenAtBOL = false
		}

		if strings.Contains(acc, qwen3ThinkingCloseTag) {
			thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, qwen3EventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = qwen3ParserStateThinkingDoneEatingWhitespace
			} else {
				p.state = qwen3ParserStateCollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, qwen3EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

		whitespaceLen := trailingWhitespaceLen(acc)
		ambiguousStart := len(acc) - whitespaceLen
		unambiguous := acc[:ambiguousStart]
		ambiguous := acc[ambiguousStart:]
		p.buffer.Reset()
		p.buffer.WriteString(ambiguous)
		if len(unambiguous) > 0 {
			events = append(events, qwen3EventThinkingContent{content: unambiguous})
		}
		return events, false

	case qwen3ParserStateThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)

	case qwen3ParserStateCollectingContent:
		acc := p.buffer.String()
		if strings.Contains(acc, qwen3ToolOpenTag) {
			before, after := p.splitAtTag(qwen3ToolOpenTag, true)
			if len(before) > 0 {
				events = append(events, qwen3EventContent{content: before})
			}
			if after == "" {
				p.state = qwen3ParserStateToolStartedEatingWhitespace
			} else {
				p.state = qwen3ParserStateCollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, qwen3EventContent{content: unambiguous})
			}
			return events, false
		}

		whitespaceLen := trailingWhitespaceLen(acc)
		ambiguousStart := len(acc) - whitespaceLen
		unambiguous := acc[:ambiguousStart]
		ambiguous := acc[ambiguousStart:]
		p.buffer.Reset()
		p.buffer.WriteString(ambiguous)
		if len(unambiguous) > 0 {
			events = append(events, qwen3EventContent{content: unambiguous})
		}
		return events, false

	case qwen3ParserStateToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)

	case qwen3ParserStateCollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, qwen3ToolCloseTag) {
			toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("qwen3 tool call closing tag found but no content before it")
			}
			events = append(events, qwen3EventRawToolCall{raw: toolContent})
			p.state = qwen3ParserStateCollectingContent
			return events, true
		}
		return events, false

	default:
		panic("unreachable")
	}
}

func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	var parsed struct {
		Name      string         `json:"name"`
		Arguments map[string]any `json:"arguments"`
	}

	if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
	}

	if parsed.Name == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.

	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      parsed.Name,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for key, value := range parsed.Arguments {
		toolCall.Function.Arguments.Set(key, value)
	}

	return toolCall, nil
}
```
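The streaming logic above leans on two small helpers, `overlap` and `trailingWhitespaceLen`, which live elsewhere in the package and are not part of this diff. A minimal sketch of plausible implementations, for orientation only (the real ones may differ):

```go
// overlap returns the length of the longest suffix of s that is also a
// prefix of tag, i.e. how many trailing bytes of s might be the start of
// a tag split across streamed chunks ("abc<thi" vs "<think>" -> 4).
func overlap(s, tag string) int {
	max := min(len(s), len(tag)-1)
	for n := max; n > 0; n-- {
		if strings.HasSuffix(s, tag[:n]) {
			return n
		}
	}
	return 0
}

// trailingWhitespaceLen reports how many trailing whitespace bytes s has;
// the parser holds that whitespace back since it may precede a closing tag.
func trailingWhitespaceLen(s string) int {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return len(s) - len(trimmed)
}
```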
@@ -1,147 +0,0 @@ (entire test file removed in this comparison; the Qwen3Parser tests, shown below for reference)

```go
package parsers

import (
	"testing"

	"github.com/ollama/ollama/api"
)

func TestQwen3ParserThinkingEnabled(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "Let me think..." {
		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
	}
	if content != "Answer." {
		t.Fatalf("expected content %q, got %q", "Answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "Let me think..." {
		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
	}
	if content != "Answer." {
		t.Fatalf("expected content %q, got %q", "Answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("<thi", false)
	if err != nil {
		t.Fatalf("parse failed on first chunk: %v", err)
	}
	if content != "" || thinking != "" || len(calls) != 0 {
		t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
	}

	content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
	if err != nil {
		t.Fatalf("parse failed on second chunk: %v", err)
	}
	if thinking != "Let me think..." {
		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
	}
	if content != "Answer." {
		t.Fatalf("expected content %q, got %q", "Answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestQwen3ParserThinkingDisabled(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	content, thinking, calls, err := parser.Add("Direct answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected no thinking, got %q", thinking)
	}
	if content != "Direct answer" {
		t.Fatalf("expected content %q, got %q", "Direct answer", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, nil)

	content, thinking, calls, err := parser.Add("Direct answer", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected no thinking, got %q", thinking)
	}
	if content != "Direct answer" {
		t.Fatalf("expected content %q, got %q", "Direct answer", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

func TestQwen3ParserToolCall(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}

	location, ok := calls[0].Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Fatalf("expected location %q, got %v", "San Francisco", location)
	}
	unit, ok := calls[0].Function.Arguments.Get("unit")
	if !ok || unit != "celsius" {
		t.Fatalf("expected unit %q, got %v", "celsius", unit)
	}
}
```
```diff
@@ -30,8 +30,6 @@ type ModelfileConfig struct {
 	Template string
 	System   string
 	License  string
-	Parser   string
-	Renderer string
 }
 
 // CreateOptions holds all options for model creation.
```
```diff
@@ -39,7 +37,7 @@ type CreateOptions struct {
 	ModelName string
 	ModelDir  string
 	Quantize  string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
-	Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
+	Modelfile *ModelfileConfig // template/system/license from Modelfile
 }
 
 // CreateModel imports a model from a local directory.
```
```diff
@@ -269,8 +267,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
 		ModelFormat:  "safetensors",
 		Capabilities: caps,
 		Requires:     MinOllamaVersion,
-		Parser:       resolveParserName(opts.Modelfile, parserName),
-		Renderer:     resolveRendererName(opts.Modelfile, rendererName),
+		Parser:       parserName,
+		Renderer:     rendererName,
 	}
 	configJSON, err := json.Marshal(configData)
 	if err != nil {
```
```diff
@@ -307,22 +305,6 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
 	}
 }
 
-func resolveParserName(mf *ModelfileConfig, inferred string) string {
-	if mf != nil && mf.Parser != "" {
-		return mf.Parser
-	}
-
-	return inferred
-}
-
-func resolveRendererName(mf *ModelfileConfig, inferred string) string {
-	if mf != nil && mf.Renderer != "" {
-		return mf.Renderer
-	}
-
-	return inferred
-}
-
 // createModelfileLayers creates layers for template, system, and license from Modelfile config.
 func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
 	var layers []manifest.Layer
```
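The removed helpers encoded a single precedence rule: an explicit Modelfile value beats the name inferred from the weights. Illustrative calls, matching the table-driven tests removed later in this diff:

```go
resolveParserName(&ModelfileConfig{Parser: "qwen3-thinking"}, "qwen3") // "qwen3-thinking"
resolveParserName(&ModelfileConfig{}, "qwen3")                        // "qwen3" (empty field falls back)
resolveParserName(nil, "qwen3")                                       // "qwen3" (nil config falls back)
```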
```diff
@@ -428,7 +410,7 @@ func getParserName(modelDir string) string {
 			return "deepseek3"
 		}
 		if strings.Contains(archLower, "qwen3") {
-			return "qwen3"
+			return "qwen3-coder"
 		}
 	}
 
@@ -442,7 +424,7 @@ func getParserName(modelDir string) string {
 			return "deepseek3"
 		}
 		if strings.Contains(typeLower, "qwen3") {
-			return "qwen3"
+			return "qwen3-coder"
 		}
 	}
```
```diff
@@ -10,8 +10,6 @@ func TestModelfileConfig(t *testing.T) {
 		Template: "{{ .Prompt }}",
 		System:   "You are a helpful assistant.",
 		License:  "MIT",
-		Parser:   "qwen3",
-		Renderer: "qwen3",
 	}
 
 	if config.Template != "{{ .Prompt }}" {
@@ -23,12 +21,6 @@ func TestModelfileConfig(t *testing.T) {
 	if config.License != "MIT" {
 		t.Errorf("License = %q, want %q", config.License, "MIT")
 	}
-	if config.Parser != "qwen3" {
-		t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
-	}
-	if config.Renderer != "qwen3" {
-		t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
-	}
 }
 
 func TestModelfileConfig_Empty(t *testing.T) {
```
```diff
@@ -43,12 +35,6 @@ func TestModelfileConfig_Empty(t *testing.T) {
 	if config.License != "" {
 		t.Errorf("License should be empty, got %q", config.License)
 	}
-	if config.Parser != "" {
-		t.Errorf("Parser should be empty, got %q", config.Parser)
-	}
-	if config.Renderer != "" {
-		t.Errorf("Renderer should be empty, got %q", config.Renderer)
-	}
 }
 
 func TestModelfileConfig_PartialFields(t *testing.T) {
```
```diff
@@ -67,12 +53,6 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
 	if config.License != "" {
 		t.Error("License should be empty")
 	}
-	if config.Parser != "" {
-		t.Error("Parser should be empty")
-	}
-	if config.Renderer != "" {
-		t.Error("Renderer should be empty")
-	}
 }
 
 func TestMinOllamaVersion(t *testing.T) {
```
```diff
@@ -118,8 +98,6 @@ func TestCreateOptions(t *testing.T) {
 			Template: "test",
 			System:   "system",
 			License:  "MIT",
-			Parser:   "qwen3-thinking",
-			Renderer: "qwen3",
 		},
 	}
```
```diff
@@ -138,92 +116,6 @@ func TestCreateOptions(t *testing.T) {
 	if opts.Modelfile.Template != "test" {
 		t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
 	}
-	if opts.Modelfile.Parser != "qwen3-thinking" {
-		t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
-	}
-	if opts.Modelfile.Renderer != "qwen3" {
-		t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
-	}
 }
 
-func TestResolveParserName(t *testing.T) {
-	tests := []struct {
-		name     string
-		mf       *ModelfileConfig
-		inferred string
-		want     string
-	}{
-		{
-			name:     "nil modelfile uses inferred",
-			mf:       nil,
-			inferred: "qwen3",
-			want:     "qwen3",
-		},
-		{
-			name: "empty parser uses inferred",
-			mf: &ModelfileConfig{
-				Parser: "",
-			},
-			inferred: "qwen3",
-			want:     "qwen3",
-		},
-		{
-			name: "explicit parser overrides inferred",
-			mf: &ModelfileConfig{
-				Parser: "qwen3-thinking",
-			},
-			inferred: "qwen3",
-			want:     "qwen3-thinking",
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
-				t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
-			}
-		})
-	}
-}
-
-func TestResolveRendererName(t *testing.T) {
-	tests := []struct {
-		name     string
-		mf       *ModelfileConfig
-		inferred string
-		want     string
-	}{
-		{
-			name:     "nil modelfile uses inferred",
-			mf:       nil,
-			inferred: "qwen3-coder",
-			want:     "qwen3-coder",
-		},
-		{
-			name: "empty renderer uses inferred",
-			mf: &ModelfileConfig{
-				Renderer: "",
-			},
-			inferred: "qwen3-coder",
-			want:     "qwen3-coder",
-		},
-		{
-			name: "explicit renderer overrides inferred",
-			mf: &ModelfileConfig{
-				Renderer: "qwen3",
-			},
-			inferred: "qwen3-coder",
-			want:     "qwen3",
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
-				t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
-			}
-		})
-	}
-}
-
 func TestCreateOptions_Defaults(t *testing.T) {
```
```diff
@@ -6,5 +6,4 @@ import (
 	_ "github.com/ollama/ollama/x/models/gemma3"
 	_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
 	_ "github.com/ollama/ollama/x/models/llama"
-	_ "github.com/ollama/ollama/x/models/qwen3"
 )
```
x/mlxrunner/model/linear.go (new file, 92 lines)

@@ -0,0 +1,92 @@

```go
//go:build mlx

package model

import (
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/models/nn"
)

// LinearFactory builds linear layers using shared tensor maps and quant defaults.
type LinearFactory struct {
	tensors          map[string]*mlx.Array
	defaultGroupSize int
	defaultBits      int
	defaultMode      string
	tensorQuant      map[string]*TensorQuantInfo
}

// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
	tensors map[string]*mlx.Array,
	defaultGroupSize, defaultBits int,
	defaultMode string,
	tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
	return LinearFactory{
		tensors:          tensors,
		defaultGroupSize: defaultGroupSize,
		defaultBits:      defaultBits,
		defaultMode:      defaultMode,
		tensorQuant:      tensorQuant,
	}
}

// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
	return MakeLinearLayer(
		f.tensors,
		path,
		f.defaultGroupSize,
		f.defaultBits,
		f.defaultMode,
		f.tensorQuant,
	)
}

// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
	tensors map[string]*mlx.Array,
	path string,
	defaultGroupSize, defaultBits int,
	defaultMode string,
	tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
	w := tensors[path+".weight"]
	if w == nil {
		return nil
	}

	scales := tensors[path+".weight_scale"]
	if scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		bias := tensors[path+".bias"]

		groupSize, bits, mode := ResolveLinearQuantParams(
			defaultGroupSize,
			defaultBits,
			defaultMode,
			tensorQuant,
			path+".weight",
			w,
			scales,
		)

		return &nn.QuantizedLinear{
			Weight:    w,
			Scales:    scales,
			QBiases:   qbiases,
			Bias:      bias,
			GroupSize: groupSize,
			Bits:      bits,
			Mode:      mode,
		}
	}

	bias := tensors[path+".bias"]
	return nn.NewLinear(w, bias)
}
```
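Typical use, mirrored by the model-loading code that appears later in this comparison: build one factory per model, then stamp out layers by tensor path.

```go
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
qProj := linears.Make("model.layers.0.self_attn.q_proj") // nil if the .weight tensor is absent
```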
```diff
@@ -18,15 +18,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return errors.New("model not loaded")
 	}
 
-	enableCompile := true
-	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
-		enableCompile = modelCompile.EnableCompile()
-	}
-	if enableCompile {
-		mlx.EnableCompile()
-	} else {
-		mlx.DisableCompile()
-	}
+	mlx.EnableCompile()
 
 	inputs := r.Tokenizer.Encode(request.Prompt, true)
```
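The removed block used an anonymous interface assertion so that any model could opt out of kernel compilation without the runner importing its package. Under that (now removed) hook, a model opted out simply by implementing the method; a hypothetical example:

```go
// Detected by the runner via r.Model.(interface{ EnableCompile() bool }).
func (m *Model) EnableCompile() bool { return false }
```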
@@ -1,338 +0,0 @@ (entire file removed in this comparison; the Qwen3 MLX text model, shown below for reference)

```go
//go:build mlx

// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3

import (
	"encoding/json"
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/tokenizer"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
	"github.com/ollama/ollama/x/models/nn"
)

func init() {
	base.Register("Qwen3ForCausalLM", newModel)
}

// Config holds Qwen3 model configuration.
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	HeadDim               int32   `json:"head_dim"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	TieWordEmbeddings     bool    `json:"tie_word_embeddings"`

	// Quantization parameters (set during load based on model quantization).
	QuantGroupSize int                               `json:"-"`
	QuantBits      int                               `json:"-"`
	QuantMode      string                            `json:"-"`
	TensorQuant    map[string]*model.TensorQuantInfo `json:"-"`

	// Computed fields.
	Scale     float32 `json:"-"`
	QKNormEps float32 `json:"-"`
}

// Model is the Qwen3 text-only model.
type Model struct {
	EmbedTokens *nn.Embedding
	Layers      []*Layer
	Norm        *nn.RMSNorm
	LMHead      nn.LinearLayer

	tok *tokenizer.Tokenizer
	*Config

	weightPrefix string
}

// Layer is a single Qwen3 decoder block.
type Layer struct {
	Attention     *Attention
	MLP           *MLP
	AttentionNorm *nn.RMSNorm
	MLPNorm       *nn.RMSNorm
}

// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
	QProj nn.LinearLayer
	KProj nn.LinearLayer
	VProj nn.LinearLayer
	OProj nn.LinearLayer
	QNorm *nn.RMSNorm
	KNorm *nn.RMSNorm
}

// MLP is the feed-forward network with SwiGLU activation.
type MLP struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
	for _, prefix := range []string{"", "language_model."} {
		if tensors[prefix+"model.embed_tokens.weight"] != nil {
			return prefix
		}
	}
	return ""
}

func newModel(root *model.Root) (base.Model, error) {
	configData, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	if cfg.HiddenSize <= 0 {
		return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
	}
	if cfg.NumAttentionHeads <= 0 {
		return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
	}
	if cfg.NumKeyValueHeads <= 0 {
		cfg.NumKeyValueHeads = cfg.NumAttentionHeads
	}
	if cfg.HeadDim == 0 {
		if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
			return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
		}
		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	}
	if cfg.HeadDim <= 0 {
		return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
	}
	if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
		return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	cfg.QKNormEps = 1e-6

	if qt := root.QuantType(); qt != "" {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
		if gs := root.GroupSize(); gs > 0 {
			cfg.QuantGroupSize = gs
		}
	} else {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
	}
	cfg.TensorQuant = root.AllTensorQuant()

	tokData, err := root.Manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData,
	}
	if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}
	if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]*Layer, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	return m, nil
}

// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	m.weightPrefix = resolveWeightPrefix(tensors)
	prefix := m.weightPrefix
	linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)

	embedWeight := tensors[prefix+"model.embed_tokens.weight"]
	if embedWeight == nil {
		return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
	}
	m.EmbedTokens = nn.NewEmbedding(embedWeight)

	normWeight := tensors[prefix+"model.norm.weight"]
	if normWeight == nil {
		return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
	}
	m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)

	if m.TieWordEmbeddings {
		m.LMHead = nn.NewLinear(embedWeight, nil)
	} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else if lmHead := linears.Make("lm_head"); lmHead != nil {
		m.LMHead = lmHead
	} else {
		// Qwen3 checkpoints commonly tie output projection to embeddings.
		m.LMHead = nn.NewLinear(embedWeight, nil)
	}

	for i := int32(0); i < m.NumHiddenLayers; i++ {
		layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)

		layer := &Layer{
			Attention: &Attention{},
			MLP:       &MLP{},
		}

		if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
			layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}
		if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
			layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
		}

		layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
		layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
		layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
		layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")

		if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
			layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
		}
		if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
			layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
		}

		layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
		layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
		layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")

		if layer.AttentionNorm == nil {
			return fmt.Errorf("layer %d: missing input_layernorm", i)
		}
		if layer.MLPNorm == nil {
			return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
		}
		if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
			return fmt.Errorf("layer %d: missing attention projections", i)
		}
		if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
			return fmt.Errorf("layer %d: missing attention q/k norms", i)
		}
		if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
			return fmt.Errorf("layer %d: missing mlp projections", i)
		}

		m.Layers[i] = layer
	}

	collected := mlx.Collect(m)
	mlx.Eval(collected...)

	return nil
}

func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Dims()
	B, L := int32(dims[0]), int32(dims[1])

	h := m.EmbedTokens.Forward(tokens)
	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil && i < len(caches) {
			c = caches[i]
		}
		h = layer.Forward(h, c, B, L, m.Config)
	}

	return m.Norm.Forward(h, m.RMSNormEps)
}

func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}

func (m *Model) NumLayers() int {
	return len(m.Layers)
}

func (m *Model) Tokenizer() *tokenizer.Tokenizer {
	return m.tok
}

func (m *Model) NewCaches() []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i := range caches {
		caches[i] = cache.NewKVCache()
	}
	return caches
}

func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
	return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}

func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
	k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
	v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)

	q = a.QNorm.Forward(q, cfg.QKNormEps)
	k = a.KNorm.Forward(k, cfg.QKNormEps)

	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
	k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)

	if c != nil {
		k, v = c.Update(k, v)
	}

	// MLX SDPA supports grouped-query attention directly (Q heads can be a
	// multiple of K/V heads), so avoid materializing repeated K/V tensors.
	out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}

func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
	return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
```
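On the grouped-query attention note in Attention.Forward: MLX's SDPA accepts more query heads than K/V heads, so the repeat-KV expansion common in other runtimes is skipped. A shape walkthrough with assumed Qwen3-8B-like numbers (32 query heads, 8 KV heads, head_dim 128; these figures are illustrative, not from this diff):

```go
// q after transpose: [B, 32, L, 128]
// k, v after cache:  [B,  8, S, 128]   (S = L plus any cached offset)
// Each group of 32/8 = 4 query heads shares one KV head, and
// mlx.ScaledDotProductAttentionCausal resolves the grouping internally,
// so k and v are never tiled up to [B, 32, S, 128].
```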
```diff
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"os"
 	"sort"
 	"strings"
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
 		}
 	}
 
-	return buildModelInfo(config, totalBytes, tensorCount), nil
+	info := buildModelInfo(config, totalBytes, tensorCount)
+
+	// For quantized models, byte-based estimation can significantly undercount
+	// parameters. Prefer exact counting from tensor shapes in safetensors headers.
+	if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
+		info["general.parameter_count"] = paramCount
+	}
+
+	return info, nil
 }
 
 // buildModelInfo constructs the model info map from config and tensor stats.
```
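Why byte-based estimates undercount: bf16 weights cost 2 bytes per parameter, while int4 weights pack eight values per uint32, roughly half a byte each. Illustrative arithmetic only; the exact heuristic inside buildModelInfo is not shown in this diff:

```go
var params int64 = 8_000_000_000 // an 8B-parameter model
bf16Bytes := params * 2          // ~16 GB on disk
int4Bytes := params / 2          // ~4 GB once packed
// An estimate that divides int4Bytes by a bf16-style bytes-per-param
// lands near params/4, hence the exact shape-based count above.
_, _ = bf16Bytes, int4Bytes
```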
```diff
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
 	return info
 }
 
+// getParameterCountFromManifest counts model parameters from tensor shapes.
+// This accounts for quantized tensors by using unpacked shapes from
+// getTensorInfoFromManifest.
+func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
+	tensors, err := getTensorInfoFromManifest(mf)
+	if err != nil {
+		return 0, err
+	}
+
+	var total int64
+	for _, tensor := range tensors {
+		if len(tensor.Shape) == 0 {
+			continue
+		}
+
+		elements := int64(1)
+		for _, dim := range tensor.Shape {
+			if dim == 0 {
+				elements = 0
+				break
+			}
+
+			if dim > uint64(math.MaxInt64) {
+				return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
+			}
+
+			d := int64(dim)
+			if elements > math.MaxInt64/d {
+				return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
+			}
+			elements *= d
+		}
+
+		if elements == 0 {
+			continue
+		}
+		if total > math.MaxInt64-elements {
+			return 0, fmt.Errorf("total parameter count overflow")
+		}
+		total += elements
+	}
+
+	return total, nil
+}
+
 // GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
 // Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
 func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
```
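The "unpacked shapes" mentioned above follow from the packing rule visible in the tests below: a uint32 word holds 32/bits quantized values, so an int4 tensor stored with packed shape [10, 2] unpacks to [10, 16]. A minimal sketch of that rule; the real unpacking lives in getTensorInfoFromManifest, which is not part of this hunk:

```go
// valuesPerWord returns how many quantized values fit in one uint32.
func valuesPerWord(bits int) uint64 { return 32 / uint64(bits) } // int4 -> 8, int8 -> 4

// unpackedCols widens a packed trailing dimension to a logical parameter count.
func unpackedCols(packedCols uint64, bits int) uint64 {
	return packedCols * valuesPerWord(bits) // 2 packed int4 cols -> 16 params per row
}
```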
@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) { (the two parameter-count tests below are added; the surrounding lines are context)

```go
	}
}

func TestGetParameterCountFromManifest(t *testing.T) {
	// Create a temp directory for blobs and set OLLAMA_MODELS
	tempDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", tempDir)

	blobDir := filepath.Join(tempDir, "blobs")
	if err := os.MkdirAll(blobDir, 0o755); err != nil {
		t.Fatalf("failed to create blobs dir: %v", err)
	}

	// Unquantized tensor: [4,5] = 20 params
	header1 := map[string]any{
		"model.embed_tokens.weight": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{4, 5},
			"data_offsets": []int64{0, 40},
		},
	}
	header1JSON, _ := json.Marshal(header1)
	var buf1 bytes.Buffer
	binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
	buf1.Write(header1JSON)

	digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
	blobPath1, err := manifest.BlobsPath(digest1)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob1: %v", err)
	}

	// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
	header2 := map[string]any{
		"__metadata__": map[string]string{
			"quant_type": "int4",
			"group_size": "32",
		},
		"model.layers.0.mlp.up_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{10, 2},
			"data_offsets": []int64{0, 80},
		},
		"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{10, 1},
			"data_offsets": []int64{80, 100},
		},
		"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{10, 1},
			"data_offsets": []int64{100, 120},
		},
	}
	header2JSON, _ := json.Marshal(header2)
	var buf2 bytes.Buffer
	binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
	buf2.Write(header2JSON)

	digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
	blobPath2, err := manifest.BlobsPath(digest2)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob2: %v", err)
	}

	mf := &manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Layers: []manifest.Layer{
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest1,
				Size:      int64(buf1.Len() + 40),
				Name:      "model.embed_tokens.weight",
			},
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest2,
				Size:      int64(buf2.Len() + 120),
				Name:      "model.layers.0.mlp.up_proj.weight",
			},
		},
	}

	paramCount, err := getParameterCountFromManifest(mf)
	if err != nil {
		t.Fatalf("getParameterCountFromManifest() error = %v", err)
	}

	const want int64 = 180 // 20 + 160
	if paramCount != want {
		t.Errorf("parameter_count = %d, want %d", paramCount, want)
	}
}

func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
	// Create a temp directory for blobs and set OLLAMA_MODELS
	tempDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", tempDir)

	blobDir := filepath.Join(tempDir, "blobs")
	if err := os.MkdirAll(blobDir, 0o755); err != nil {
		t.Fatalf("failed to create blobs dir: %v", err)
	}

	// Packed mixed-precision blob (no global metadata):
	// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
	// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
	header := map[string]any{
		"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{5, 8},
			"data_offsets": []int64{0, 160},
		},
		"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 2},
			"data_offsets": []int64{160, 180},
		},
		"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 2},
			"data_offsets": []int64{180, 200},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
			"dtype":        "U32",
			"shape":        []int64{5, 16},
			"data_offsets": []int64{200, 520},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 1},
			"data_offsets": []int64{520, 530},
		},
		"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
			"dtype":        "BF16",
			"shape":        []int64{5, 1},
			"data_offsets": []int64{530, 540},
		},
	}
	headerJSON, _ := json.Marshal(header)
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
	buf.Write(headerJSON)

	digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		t.Fatalf("failed to get blob path: %v", err)
	}
	if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
		t.Fatalf("failed to write blob: %v", err)
	}

	mf := &manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Layers: []manifest.Layer{
			{
				MediaType: manifest.MediaTypeImageTensor,
				Digest:    digest,
				Size:      int64(buf.Len() + 540),
				Name:      "model.layers.0.mlp.experts",
			},
		},
	}

	paramCount, err := getParameterCountFromManifest(mf)
	if err != nil {
		t.Fatalf("getParameterCountFromManifest() error = %v", err)
	}

	const want int64 = 640 // 320 + 320
	if paramCount != want {
		t.Errorf("parameter_count = %d, want %d", paramCount, want)
	}
}

func TestParseSafetensorsAllHeaders(t *testing.T) {
	tests := []struct {
		name string
```