mirror of
https://github.com/ollama/ollama.git
synced 2026-03-06 16:08:21 -05:00
Compare commits
1 Commits
parth-pi-t
...
jessegross
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d5ff25724 |
@@ -1063,7 +1063,7 @@ func DefaultOptions() Options {
|
||||
TopP: 0.9,
|
||||
TypicalP: 1.0,
|
||||
RepeatLastN: 64,
|
||||
RepeatPenalty: 1.0,
|
||||
RepeatPenalty: 1.1,
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
|
||||
@@ -35,7 +35,6 @@ import (
|
||||
var (
|
||||
wv = &Webview{}
|
||||
uiServerPort int
|
||||
appStore *store.Store
|
||||
)
|
||||
|
||||
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
||||
@@ -209,7 +208,6 @@ func main() {
|
||||
uiServerPort = port
|
||||
|
||||
st := &store.Store{}
|
||||
appStore = st
|
||||
|
||||
// Enable CORS in development mode
|
||||
if devMode {
|
||||
@@ -296,15 +294,8 @@ func main() {
|
||||
|
||||
// Check for pending updates on startup (show tray notification if update is ready)
|
||||
if updater.IsUpdatePending() {
|
||||
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
|
||||
// before that would dereference a nil tray callback.
|
||||
// TODO: refactor so the update check runs after platform init on all platforms.
|
||||
if runtime.GOOS == "windows" {
|
||||
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
|
||||
} else {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||
@@ -369,7 +360,8 @@ func startHiddenTasks() {
|
||||
slog.Info("deferring pending update for fast startup")
|
||||
} else {
|
||||
// Check if auto-update is enabled before automatically upgrading
|
||||
settings, err := appStore.Settings()
|
||||
st := &store.Store{}
|
||||
settings, err := st.Settings()
|
||||
if err != nil {
|
||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||
} else if !settings.AutoUpdateEnabled {
|
||||
|
||||
@@ -154,10 +154,6 @@ func handleURLSchemeRequest(urlScheme string) {
|
||||
}
|
||||
|
||||
func UpdateAvailable(ver string) error {
|
||||
if app.t == nil {
|
||||
slog.Debug("tray not yet initialized, skipping update notification")
|
||||
return nil
|
||||
}
|
||||
return app.t.UpdateAvailable(ver)
|
||||
}
|
||||
|
||||
@@ -169,14 +165,6 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
||||
log.Fatalf("Failed to start: %s", err)
|
||||
}
|
||||
|
||||
// Check for pending updates now that the tray is initialized.
|
||||
// The platform-independent check in app.go fires before osRun,
|
||||
// when app.t is still nil, so we must re-check here.
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
|
||||
@@ -289,7 +289,6 @@ func (u *Updater) TriggerImmediateCheck() {
|
||||
|
||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||
u.checkNow = make(chan struct{}, 1)
|
||||
u.checkNow <- struct{}{} // Trigger first check after initial delay
|
||||
go func() {
|
||||
// Don't blast an update message immediately after startup
|
||||
time.Sleep(UpdateCheckInitialDelay)
|
||||
@@ -334,7 +333,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
|
||||
continue
|
||||
}
|
||||
|
||||
// Download successful - show tray notification
|
||||
// Download successful - show tray notification (regardless of toggle state)
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn("failed to register update available with tray", "error", err)
|
||||
|
||||
@@ -351,13 +351,10 @@ func TestTriggerImmediateCheck(t *testing.T) {
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait for the initial check that fires after the initial delay
|
||||
select {
|
||||
case <-checkDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("initial check did not happen")
|
||||
}
|
||||
// Wait for goroutine to start and pass initial delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// With 1 hour interval, no check should have happened yet
|
||||
initialCount := checkCount.Load()
|
||||
|
||||
// Trigger immediate check
|
||||
|
||||
11
cmd/cmd.go
11
cmd/cmd.go
@@ -585,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
useImagegen := false
|
||||
if cmd.Flags().Lookup("imagegen") != nil {
|
||||
useImagegen, err = cmd.Flags().GetBool("imagegen")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if useImagegen {
|
||||
opts.Options["use_imagegen_runner"] = true
|
||||
}
|
||||
|
||||
// Fill out the rest of the options based on information about the
|
||||
// model.
|
||||
client, err := api.ClientFromEnvironment()
|
||||
|
||||
@@ -122,18 +122,13 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
|
||||
@@ -232,44 +232,6 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate Ollama (local) provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"Ollama (local)","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "Ollama" {
|
||||
t.Errorf("provider name not migrated: got %q, want %q", ollama["name"], "Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve custom provider name", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"name":"My Custom Ollama","npm":"@ai-sdk/openai-compatible","options":{"baseURL":"http://localhost:11434/v1"}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
if ollama["name"] != "My Custom Ollama" {
|
||||
t.Errorf("custom provider name was changed: got %q, want %q", ollama["name"], "My Custom Ollama")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
@@ -155,8 +155,6 @@ func (p *Pi) Edit(models []string) error {
|
||||
|
||||
settings["defaultProvider"] = "ollama"
|
||||
settings["defaultModel"] = models[0]
|
||||
// TODO(parthsareen): temporary fix for happy path install. should treat thinking level as true for thinking when needed
|
||||
settings["defaultThinkingLevel"] = "off"
|
||||
|
||||
settingsData, err := json.MarshalIndent(settings, "", " ")
|
||||
if err != nil {
|
||||
|
||||
@@ -437,11 +437,6 @@ func TestPiEdit(t *testing.T) {
|
||||
t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
|
||||
}
|
||||
|
||||
// Verify defaultThinkingLevel is set to off
|
||||
if settings["defaultThinkingLevel"] != "off" {
|
||||
t.Errorf("defaultThinkingLevel = %v, want off", settings["defaultThinkingLevel"])
|
||||
}
|
||||
|
||||
// Verify other fields are preserved
|
||||
if settings["theme"] != "dark" {
|
||||
t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
|
||||
|
||||
@@ -152,9 +152,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.0) | float | repeat_penalty 1.0 |
|
||||
| presence_penalty | Penalizes tokens that have already appeared in the generated text to reduce repetition. (Default: 0.0) | float | presence_penalty 1.5 |
|
||||
| frequency_penalty | Penalizes tokens based on how often they have appeared in the generated text. (Default: 0.0) | float | frequency_penalty 1.0 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||
|
||||
@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
@@ -135,18 +135,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
@@ -454,10 +442,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
// Collect chunk outputs and concatenate at the end.
|
||||
// Avoids SET on buffer-less intermediates under partial offload.
|
||||
chunks := make([]ml.Tensor, nChunks)
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -479,7 +463,14 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
chunks[chunk] = coreAttnOutChunk
|
||||
v = v.SetInplace(
|
||||
ctx,
|
||||
coreAttnOutChunk,
|
||||
v.Stride(1),
|
||||
v.Stride(2),
|
||||
v.Stride(3),
|
||||
chunk*v.Stride(2),
|
||||
)
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
@@ -492,20 +483,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||
}
|
||||
|
||||
// Use a balanced concat tree so concat work does not balloon on long prompts.
|
||||
for len(chunks) > 1 {
|
||||
merged := make([]ml.Tensor, 0, (len(chunks)+1)/2)
|
||||
for i := 0; i < len(chunks); i += 2 {
|
||||
if i+1 < len(chunks) {
|
||||
merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2))
|
||||
} else {
|
||||
merged = append(merged, chunks[i])
|
||||
}
|
||||
}
|
||||
chunks = merged
|
||||
}
|
||||
v = chunks[0]
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
|
||||
@@ -437,46 +437,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.Options == nil {
|
||||
return fmt.Errorf("qwen3next: missing model options")
|
||||
}
|
||||
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if !m.Options.isRecurrent[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||
if !ok || gdn == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||
}
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||
}
|
||||
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
if len(m.mropeSections) > 0 {
|
||||
@@ -490,64 +450,6 @@ var (
|
||||
_ model.MultimodalProcessor = (*Model)(nil)
|
||||
)
|
||||
|
||||
func defaultVHeadReordered(arch string) bool {
|
||||
return arch == "qwen35" || arch == "qwen35moe"
|
||||
}
|
||||
|
||||
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
if i >= len(headCountKV) {
|
||||
continue
|
||||
}
|
||||
|
||||
if headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if hasZero && hasFull {
|
||||
return isRecurrent, nil
|
||||
}
|
||||
if !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Compatibility path: older imports store a scalar KV head count and omit
|
||||
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||
interval := int(fullAttentionInterval)
|
||||
if interval == 0 {
|
||||
interval = min(4, numLayers)
|
||||
}
|
||||
if interval <= 0 {
|
||||
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||
}
|
||||
if interval > numLayers {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||
}
|
||||
|
||||
hasZero = false
|
||||
hasFull = false
|
||||
for i := range numLayers {
|
||||
isRecurrent[i] = (i+1)%interval != 0
|
||||
if isRecurrent[i] {
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||
}
|
||||
|
||||
return isRecurrent, nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
@@ -558,14 +460,26 @@ func New(c fs.Config) (model.Model, error) {
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
@@ -629,7 +543,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
||||
isRecurrent: isRecurrent,
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range mropeSections {
|
||||
@@ -641,7 +555,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, false, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, true, false, true, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, false, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||
if err == nil {
|
||||
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultVHeadReordered(t *testing.T) {
|
||||
if !defaultVHeadReordered("qwen35") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||
}
|
||||
if !defaultVHeadReordered("qwen35moe") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||
}
|
||||
if defaultVHeadReordered("qwen3next") {
|
||||
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{
|
||||
Operator: &GatedDeltaNet{
|
||||
SSMQKV: &nn.Linear{},
|
||||
SSMQKVGate: &nn.Linear{},
|
||||
SSMBeta: &nn.Linear{},
|
||||
SSMAlpha: &nn.Linear{},
|
||||
},
|
||||
}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{true},
|
||||
},
|
||||
}
|
||||
|
||||
err := m.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{false},
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -32,10 +32,9 @@ const (
|
||||
)
|
||||
|
||||
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
@@ -49,7 +48,6 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools
|
||||
}
|
||||
|
||||
@@ -91,8 +89,6 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
|
||||
@@ -11,7 +11,6 @@ type GLM47Parser struct {
|
||||
|
||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
|
||||
@@ -97,91 +97,3 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `plan</think>
|
||||
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func ParserForName(name string) Parser {
|
||||
case "qwen3-thinking":
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3.5":
|
||||
p = &Qwen35Parser{}
|
||||
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
case "qwen3-vl-instruct":
|
||||
|
||||
@@ -38,7 +38,6 @@ type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
@@ -55,7 +54,6 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
p.callIndex = 0
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
@@ -108,8 +106,6 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
@@ -208,24 +204,6 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
@@ -237,7 +215,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// qwen35ParserState tracks which phase of the model output the parser is in.
type qwen35ParserState int

const (
	// qwen35ParserStateCollectingThinking accumulates reasoning text until </think>.
	qwen35ParserStateCollectingThinking qwen35ParserState = iota
	// qwen35ParserStateThinkingDoneEatingWhitespace discards whitespace
	// between </think> and the first real content.
	qwen35ParserStateThinkingDoneEatingWhitespace
	// qwen35ParserStateCollectingContent passes text through to the tool-call parser.
	qwen35ParserStateCollectingContent
)

const (
	qwen35ThinkingOpenTag  = "<think>"
	qwen35ThinkingCloseTag = "</think>"
)

// Qwen35Parser handles qwen3.5 reasoning extraction and delegates post-thinking
// content (including XML tool calls) to Qwen3CoderParser.
type Qwen35Parser struct {
	// toolParser receives everything after the thinking section.
	toolParser Qwen3CoderParser

	// state is the current phase; buffer holds input not yet emitted as events.
	state  qwen35ParserState
	buffer strings.Builder
	// Some checkpoints may emit an explicit leading <think> even when the
	// prompt already opened thinking. Strip at most one such tag.
	allowLeadingThinkOpenTag bool
}
|
||||
|
||||
// HasToolSupport reports that qwen3.5 output can contain tool calls.
func (p *Qwen35Parser) HasToolSupport() bool {
	return true
}

// HasThinkingSupport reports that qwen3.5 output can contain a thinking section.
func (p *Qwen35Parser) HasThinkingSupport() bool {
	return true
}
|
||||
|
||||
func (p *Qwen35Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.buffer.Reset()
|
||||
p.toolParser = Qwen3CoderParser{}
|
||||
p.toolParser.Init(tools, nil, nil)
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
|
||||
assistantPrefill := lastMessage != nil && lastMessage.Role == "assistant" && lastMessage.Content != ""
|
||||
if thinkingEnabled && !assistantPrefill {
|
||||
p.state = qwen35ParserStateCollectingThinking
|
||||
p.allowLeadingThinkOpenTag = true
|
||||
} else {
|
||||
p.state = qwen35ParserStateCollectingContent
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// qwen35Event is the sum type produced by the incremental parse loop.
type qwen35Event interface {
	isQwen35Event()
}

// qwen35EventContent carries post-thinking text; it may still contain XML
// tool calls for the delegate coder parser to extract.
type qwen35EventContent struct {
	content string
}

func (qwen35EventContent) isQwen35Event() {}

// qwen35EventThinkingContent carries reasoning text from the thinking section.
type qwen35EventThinkingContent struct {
	content string
}

func (qwen35EventThinkingContent) isQwen35Event() {}
|
||||
|
||||
func (p *Qwen35Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwen35EventContent:
|
||||
parsedContent, _, parsedCalls, err := p.toolParser.Add(event.content, done)
|
||||
if err != nil {
|
||||
slog.Warn("qwen3.5 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
contentSb.WriteString(parsedContent)
|
||||
calls = append(calls, parsedCalls...)
|
||||
case qwen35EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen35Parser) parseEvents() []qwen35Event {
|
||||
var all []qwen35Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwen35Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3.5 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// splitAtTag splits the buffered text around the first occurrence of tag by
// delegating to the package-level splitAtTag helper, which operates on (and
// presumably consumes/rewrites) the buffer — confirm against the helper.
// trimAfter is forwarded unchanged; it appears to control trimming of leading
// whitespace after the tag, as used by eat when closing the thinking section.
func (p *Qwen35Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
	return splitAtTag(&p.buffer, tag, trimAfter)
}
|
||||
|
||||
func (p *Qwen35Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen35ParserState) ([]qwen35Event, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// maybeConsumeLeadingThinkOpenTag handles a single optional leading <think> tag.
|
||||
// Returns (handled, shouldContinueParsingNow).
|
||||
func (p *Qwen35Parser) maybeConsumeLeadingThinkOpenTag(acc string) (bool, bool) {
|
||||
if !p.allowLeadingThinkOpenTag {
|
||||
return false, false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, qwen35ThinkingOpenTag) {
|
||||
after := strings.TrimPrefix(trimmed, qwen35ThinkingOpenTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
if after == "" {
|
||||
return true, false
|
||||
}
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
return true, true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(qwen35ThinkingOpenTag, trimmed) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
p.allowLeadingThinkOpenTag = false
|
||||
return false, false
|
||||
}
|
||||
|
||||
// eat makes one step of progress on the buffered text. It returns the events
// produced by that step and whether the caller should immediately call eat
// again (true after a state transition that may leave parseable input behind).
func (p *Qwen35Parser) eat() ([]qwen35Event, bool) {
	var events []qwen35Event

	switch p.state {
	case qwen35ParserStateCollectingThinking:
		acc := p.buffer.String()

		// First swallow an optional duplicated leading <think> tag.
		if handled, continueNow := p.maybeConsumeLeadingThinkOpenTag(acc); handled {
			return events, continueNow
		}

		if strings.Contains(acc, qwen35ThinkingCloseTag) {
			// Thinking section is complete: emit it and move on to content
			// (via a whitespace-eating state if nothing follows the tag yet).
			thinking, remaining := p.splitAtTag(qwen35ThinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, qwen35EventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = qwen35ParserStateThinkingDoneEatingWhitespace
			} else {
				p.state = qwen35ParserStateCollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, qwen35ThinkingCloseTag); overlapLen > 0 {
			// The buffer ends in a partial "</think>". Hold back that suffix
			// (and any whitespace before it, which would be trimmed if the
			// tag completes) and emit only the unambiguous prefix.
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, qwen35EventThinkingContent{content: unambiguous})
			}
			return events, false
		}

		// No close tag in sight: stream out everything except trailing
		// whitespace, which is held back in case the close tag comes next.
		whitespaceLen := trailingWhitespaceLen(acc)
		ambiguousStart := len(acc) - whitespaceLen
		unambiguous := acc[:ambiguousStart]
		ambiguous := acc[ambiguousStart:]
		p.buffer.Reset()
		p.buffer.WriteString(ambiguous)
		if len(unambiguous) > 0 {
			events = append(events, qwen35EventThinkingContent{content: unambiguous})
		}
		return events, false

	case qwen35ParserStateThinkingDoneEatingWhitespace:
		// Discard whitespace between </think> and the first real content.
		return p.eatLeadingWhitespaceAndTransitionTo(qwen35ParserStateCollectingContent)

	case qwen35ParserStateCollectingContent:
		if p.buffer.Len() == 0 {
			return events, false
		}

		// Pass everything through; tool-call extraction happens downstream.
		content := p.buffer.String()
		p.buffer.Reset()
		if len(content) > 0 {
			events = append(events, qwen35EventContent{content: content})
		}
		return events, false

	default:
		// Should be unreachable; recover by falling back to content mode.
		slog.Warn("qwen3.5 parser entered unknown state; resetting to content mode", "state", p.state)
		p.state = qwen35ParserStateCollectingContent
		return events, false
	}
}
|
||||
@@ -1,382 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestQwen35ParserXMLToolCall verifies that with thinking disabled, an XML
// tool call is parsed into a single api.ToolCall with typed arguments
// (string location, integer days) and no content/thinking output.
func TestQwen35ParserXMLToolCall(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: func() *api.ToolPropertiesMap {
						props := api.NewToolPropertiesMap()
						props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
						props.Set("days", api.ToolProperty{Type: api.PropertyType{"integer"}})
						return props
					}(),
				},
			},
		},
	}

	parser.Init(tools, nil, &api.ThinkValue{Value: false})
	input := "<tool_call><function=get_weather><parameter=location>\nSan Francisco\n</parameter><parameter=days>\n3\n</parameter></function></tool_call>"
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}

	location, ok := calls[0].Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Fatalf("expected location %q, got %v", "San Francisco", location)
	}

	days, ok := calls[0].Function.Arguments.Get("days")
	if !ok || days != 3 {
		t.Fatalf("expected days %d, got %v", 3, days)
	}
}
|
||||
|
||||
// TestQwen35ParserThinkingWithExplicitOpeningTag verifies that a duplicated
// leading <think> tag is stripped and the text before </think> becomes
// thinking while the rest becomes content.
func TestQwen35ParserThinkingWithExplicitOpeningTag(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: true})
	content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "Let me think..." {
		t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
	}
	if content != "Answer." {
		t.Fatalf("expected content %q, got %q", "Answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

// TestQwen35ParserAssistantPrefillStartsInContent verifies that continuing a
// non-empty assistant message (prefill) skips thinking mode entirely, even
// when thinkValue is nil.
func TestQwen35ParserAssistantPrefillStartsInContent(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	last := &api.Message{Role: "assistant", Content: "Prefilled response start"}
	parser.Init(nil, last, nil)

	content, thinking, calls, err := parser.Add(" and continued", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected no thinking for assistant prefill continuation, got %q", thinking)
	}
	if content != " and continued" {
		t.Fatalf("expected content %q, got %q", " and continued", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}
|
||||
|
||||
// TestQwen35ParserToolCallEmittedInThinkingIsNotParsed verifies that a tool
// call appearing before </think> is treated as literal thinking text, not
// parsed into an api.ToolCall.
func TestQwen35ParserToolCallEmittedInThinkingIsNotParsed(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: func() *api.ToolPropertiesMap {
						props := api.NewToolPropertiesMap()
						props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
						return props
					}(),
				},
			},
		},
	}

	parser.Init(tools, nil, &api.ThinkValue{Value: true})
	input := `Need weather lookup<tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	expectedThinking := `Need weather lookup<tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
	if thinking != expectedThinking {
		t.Fatalf("expected thinking %q, got %q", expectedThinking, thinking)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls before </think>, got %d", len(calls))
	}
}

// TestQwen35ParserToolCallAfterThinkingCloseIsParsed verifies the complement:
// a tool call after </think> is parsed into an api.ToolCall with arguments.
func TestQwen35ParserToolCallAfterThinkingCloseIsParsed(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: func() *api.ToolPropertiesMap {
						props := api.NewToolPropertiesMap()
						props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
						return props
					}(),
				},
			},
		},
	}

	parser.Init(tools, nil, &api.ThinkValue{Value: true})
	input := `Need weather lookup</think><tool_call><function=get_weather><parameter=location>
SF
</parameter></function></tool_call>`
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "Need weather lookup" {
		t.Fatalf("expected thinking %q, got %q", "Need weather lookup", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call after </think>, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}

	location, ok := calls[0].Function.Arguments.Get("location")
	if !ok || location != "SF" {
		t.Fatalf("expected location %q, got %v", "SF", location)
	}
}
|
||||
|
||||
// TestQwen35ParserThinkingDisabledPassesContentThrough verifies that with
// thinking off, plain text is returned as content verbatim.
func TestQwen35ParserThinkingDisabledPassesContentThrough(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: false})
	content, thinking, calls, err := parser.Add("Plain answer without think close tag.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "Plain answer without think close tag." {
		t.Fatalf("expected content %q, got %q", "Plain answer without think close tag.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

// TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent verifies that a
// spurious </think> is left in the output verbatim when thinking is off.
func TestQwen35ParserThinkingDisabledWithCloseTagTreatsAsContent(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: false})
	content, thinking, calls, err := parser.Add("</think>Some content after spurious tag.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "</think>Some content after spurious tag." {
		t.Fatalf("expected content %q, got %q", "</think>Some content after spurious tag.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

// TestQwen35ParserLeadingThinkCloseProducesContent verifies that an immediate
// </think> (empty thinking section) yields only content.
func TestQwen35ParserLeadingThinkCloseProducesContent(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: true})
	content, thinking, calls, err := parser.Add("</think>The final answer.", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "" {
		t.Fatalf("expected empty thinking, got %q", thinking)
	}
	if content != "The final answer." {
		t.Fatalf("expected content %q, got %q", "The final answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}
|
||||
|
||||
// TestQwen35ParserStreamingSplitThinkCloseTag verifies that a "</think>" tag
// split across two chunks is buffered and resolved correctly.
func TestQwen35ParserStreamingSplitThinkCloseTag(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("Reasoning text</thi", false)
	if err != nil {
		t.Fatalf("parse failed on first chunk: %v", err)
	}
	if thinking != "Reasoning text" {
		t.Fatalf("expected thinking %q, got %q", "Reasoning text", thinking)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}

	content, thinking, calls, err = parser.Add("nk>The final answer.", true)
	if err != nil {
		t.Fatalf("parse failed on second chunk: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
	}
	if content != "The final answer." {
		t.Fatalf("expected content %q, got %q", "The final answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

// TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose verifies that
// whitespace-only chunks between </think> and content are swallowed.
func TestQwen35ParserStreamingEatsWhitespaceAfterThinkClose(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("Reasoning</think>", false)
	if err != nil {
		t.Fatalf("parse failed on first chunk: %v", err)
	}
	if thinking != "Reasoning" {
		t.Fatalf("expected thinking %q, got %q", "Reasoning", thinking)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}

	content, thinking, calls, err = parser.Add("\n \t", false)
	if err != nil {
		t.Fatalf("parse failed on whitespace chunk: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected no thinking on whitespace chunk, got %q", thinking)
	}
	if content != "" {
		t.Fatalf("expected whitespace after </think> to be eaten, got content %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}

	content, thinking, calls, err = parser.Add("The final answer.", true)
	if err != nil {
		t.Fatalf("parse failed on content chunk: %v", err)
	}
	if thinking != "" {
		t.Fatalf("expected no additional thinking, got %q", thinking)
	}
	if content != "The final answer." {
		t.Fatalf("expected content %q, got %q", "The final answer.", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}

// TestQwen35ParserThinkingTruncatedWithoutCloseTag verifies that reasoning
// with no closing tag is still flushed as thinking on the final chunk.
func TestQwen35ParserThinkingTruncatedWithoutCloseTag(t *testing.T) {
	parser := ParserForName("qwen3.5")
	if parser == nil {
		t.Fatal("expected qwen3.5 parser")
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: true})
	content, thinking, calls, err := parser.Add("Reasoning that never closes", true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if thinking != "Reasoning that never closes" {
		t.Fatalf("expected thinking %q, got %q", "Reasoning that never closes", thinking)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if len(calls) != 0 {
		t.Fatalf("expected no tool calls, got %d", len(calls))
	}
}
|
||||
@@ -146,68 +146,6 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwen3ParserThinkingWithToolCallBeforeThinkingClose verifies that for
// the qwen3 (JSON tool call) parser, a <tool_call> appearing before </think>
// ends the thinking section and is parsed as a real tool call.
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
	content, thinking, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "Let me think" {
		t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}
}

// TestQwen3ParserThinkingWithSplitToolOpenTag verifies that a "<tool_call>"
// tag split across streaming chunks still terminates thinking and produces
// one tool call once complete.
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
	if err != nil {
		t.Fatalf("parse failed on first chunk: %v", err)
	}
	if content != "" || thinking != "Let me think" || len(calls) != 0 {
		t.Fatalf(
			"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
			"",
			"Let me think",
			0,
			content,
			thinking,
			len(calls),
		)
	}

	content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
	if err != nil {
		t.Fatalf("parse failed on second chunk: %v", err)
	}
	if content != "" {
		t.Fatalf("expected empty content, got %q", content)
	}
	if thinking != "" {
		t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if calls[0].Function.Name != "get_weather" {
		t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
	}
}
|
||||
|
||||
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
@@ -230,89 +168,3 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwen3ParserToolCallIndexing verifies that multiple tool calls in one
// chunk receive sequential Function.Index values starting at 0.
func TestQwen3ParserToolCallIndexing(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
	_, _, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	want := []api.ToolCall{
		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
	}
	if len(calls) != len(want) {
		t.Fatalf("expected %d calls, got %d", len(want), len(calls))
	}
	for i := range want {
		if !toolCallEqual(calls[i], want[i]) {
			t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
		}
	}
}

// TestQwen3ParserToolCallIndexingStreaming verifies that indices keep
// incrementing across chunk boundaries, including a call split mid-JSON.
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	var all []api.ToolCall

	_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
	if err != nil {
		t.Fatalf("step 1 parse failed: %v", err)
	}
	all = append(all, calls...)

	_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
	if err != nil {
		t.Fatalf("step 2 parse failed: %v", err)
	}
	all = append(all, calls...)

	want := []api.ToolCall{
		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
	}
	if len(all) != len(want) {
		t.Fatalf("expected %d calls, got %d", len(want), len(all))
	}
	for i := range want {
		if !toolCallEqual(all[i], want[i]) {
			t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
		}
	}
}

// TestQwen3ParserToolCallIndexResetOnInit verifies that re-initializing the
// parser resets the tool call index back to 0.
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
	parser.Init(nil, nil, &api.ThinkValue{Value: false})

	_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
	if err != nil {
		t.Fatalf("first parse failed: %v", err)
	}

	parser.Init(nil, nil, &api.ThinkValue{Value: false})
	_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
	if err != nil {
		t.Fatalf("second parse failed: %v", err)
	}

	want := api.ToolCall{
		Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 call, got %d", len(calls))
	}
	if !toolCallEqual(calls[0], want) {
		t.Fatalf("got %#v, want %#v", calls[0], want)
	}
}
|
||||
|
||||
@@ -29,10 +29,9 @@ const (
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
@@ -45,7 +44,6 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
|
||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools // Qwen doesn't modify tools
|
||||
}
|
||||
|
||||
@@ -64,8 +62,6 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
|
||||
@@ -1035,92 +1035,6 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwen3CoderParserToolCallIndexing verifies that multiple XML tool calls
// in one chunk receive sequential Function.Index values starting at 0.
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
	parser := Qwen3CoderParser{}
	parser.Init(nil, nil, nil)

	input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
	_, _, calls, err := parser.Add(input, true)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	want := []api.ToolCall{
		{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
		{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
		{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
	}
	if len(calls) != len(want) {
		t.Fatalf("expected %d calls, got %d", len(want), len(calls))
	}
	for i := range want {
		if !toolCallEqual(calls[i], want[i]) {
			t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
		}
	}
}

// TestQwen3CoderParserToolCallIndexingStreaming verifies that indices keep
// incrementing across chunk boundaries, including a call split mid-XML.
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
	parser := Qwen3CoderParser{}
	parser.Init(nil, nil, nil)

	var all []api.ToolCall

	_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
	if err != nil {
		t.Fatalf("step 1 parse failed: %v", err)
	}
	all = append(all, calls...)

	_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
	if err != nil {
		t.Fatalf("step 2 parse failed: %v", err)
	}
	all = append(all, calls...)

	want := []api.ToolCall{
		{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
		{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
		{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
	}
	if len(all) != len(want) {
		t.Fatalf("expected %d calls, got %d", len(want), len(all))
	}
	for i := range want {
		if !toolCallEqual(all[i], want[i]) {
			t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
		}
	}
}

// TestQwen3CoderParserToolCallIndexResetOnInit verifies that re-initializing
// the parser resets the tool call index back to 0.
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
	parser := Qwen3CoderParser{}
	parser.Init(nil, nil, nil)

	_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
	if err != nil {
		t.Fatalf("first parse failed: %v", err)
	}

	parser.Init(nil, nil, nil)
	_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
	if err != nil {
		t.Fatalf("second parse failed: %v", err)
	}

	want := api.ToolCall{
		Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
	}
	if len(calls) != 1 {
		t.Fatalf("expected 1 call, got %d", len(calls))
	}
	if !toolCallEqual(calls[0], want) {
		t.Fatalf("got %#v, want %#v", calls[0], want)
	}
}
|
||||
|
||||
func TestQwenXMLTransform(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
|
||||
@@ -180,22 +180,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
return events, false
|
||||
}
|
||||
case CollectingThinkingContent:
|
||||
acc := p.buffer.String()
|
||||
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, toolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: before})
|
||||
}
|
||||
p.state = CollectingToolContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||
@@ -206,13 +191,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
p.state = CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
@@ -220,11 +205,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
|
||||
@@ -98,12 +98,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -113,7 +109,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
|
||||
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -124,8 +121,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
|
||||
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
|
||||
@@ -8,21 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type GlmOcrRenderer struct {
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||
var sb strings.Builder
|
||||
for range message.Images {
|
||||
if r.useImgTags {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
}
|
||||
}
|
||||
sb.WriteString(message.Content)
|
||||
return sb.String(), imageOffset
|
||||
}
|
||||
type GlmOcrRenderer struct{}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
@@ -52,14 +38,11 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
thinkingExplicitlySet = true
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
content, nextOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextOffset
|
||||
sb.WriteString(content)
|
||||
sb.WriteString(message.Content)
|
||||
if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer *GlmOcrRenderer
|
||||
messages []api.Message
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "use_img_tags_single_image",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "use_img_tags_multiple_images",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Describe these images.",
|
||||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn_increments_image_offset",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "First image",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Processed.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Second image",
|
||||
Images: []api.ImageData{api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "default_no_img_tags",
|
||||
renderer: &GlmOcrRenderer{},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "No image tags expected.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nNo image tags expected.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "no_images_content_unchanged",
|
||||
renderer: &GlmOcrRenderer{useImgTags: true},
|
||||
messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Text only message.",
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\nText only message.<|assistant|>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.renderer.Render(tt.messages, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Render() error = %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.expected, got); diff != "" {
|
||||
t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
qwen35ThinkOpenTag = "<think>"
|
||||
qwen35ThinkCloseTag = "</think>"
|
||||
qwen35ToolPostamble = `
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT>`
|
||||
)
|
||||
|
||||
type Qwen35Renderer struct {
|
||||
isThinking bool
|
||||
|
||||
emitEmptyThinkOnNoThink bool
|
||||
useImgTags bool
|
||||
}
|
||||
|
||||
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||
var subSb strings.Builder
|
||||
for range content.Images {
|
||||
if r.useImgTags {
|
||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
} else {
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
}
|
||||
// TODO: support videos
|
||||
|
||||
subSb.WriteString(content.Content)
|
||||
return subSb.String(), imageOffset
|
||||
}
|
||||
|
||||
func splitQwen35ReasoningContent(content, messageThinking string, isThinking bool) (reasoning string, remaining string) {
|
||||
if isThinking && messageThinking != "" {
|
||||
return strings.TrimSpace(messageThinking), content
|
||||
}
|
||||
|
||||
if idx := strings.Index(content, qwen35ThinkCloseTag); idx != -1 {
|
||||
before := content[:idx]
|
||||
if open := strings.LastIndex(before, qwen35ThinkOpenTag); open != -1 {
|
||||
reasoning = before[open+len(qwen35ThinkOpenTag):]
|
||||
} else {
|
||||
reasoning = before
|
||||
}
|
||||
content = strings.TrimLeft(content[idx+len(qwen35ThinkCloseTag):], "\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(reasoning), content
|
||||
}
|
||||
|
||||
func (r *Qwen35Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
isThinking := r.isThinking
|
||||
if think != nil {
|
||||
isThinking = think.Bool()
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(imStartTag + "system\n")
|
||||
sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n<tools>")
|
||||
for _, tool := range tools {
|
||||
sb.WriteString("\n")
|
||||
if b, err := marshalWithSpaces(tool); err == nil {
|
||||
sb.Write(b)
|
||||
}
|
||||
}
|
||||
sb.WriteString(qwen35ToolPostamble)
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
systemContent, _ := r.renderContent(messages[0], 0)
|
||||
systemContent = strings.TrimSpace(systemContent)
|
||||
if systemContent != "" {
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(systemContent)
|
||||
}
|
||||
}
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
} else if len(messages) > 0 && messages[0].Role == "system" {
|
||||
systemContent, _ := r.renderContent(messages[0], 0)
|
||||
sb.WriteString(imStartTag + "system\n" + strings.TrimSpace(systemContent) + imEndTag + "\n")
|
||||
}
|
||||
|
||||
multiStepTool := true
|
||||
lastQueryIndex := len(messages) - 1 // so this is the last user message
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if multiStepTool && message.Role == "user" {
|
||||
content, _ := r.renderContent(message, 0)
|
||||
content = strings.TrimSpace(content)
|
||||
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
||||
multiStepTool = false
|
||||
lastQueryIndex = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageOffset := 0
|
||||
for i, message := range messages {
|
||||
content, nextImageOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextImageOffset
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
lastMessage := i == len(messages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
|
||||
if message.Role == "user" || (message.Role == "system" && i != 0) {
|
||||
sb.WriteString(imStartTag + message.Role + "\n" + content + imEndTag + "\n")
|
||||
} else if message.Role == "assistant" {
|
||||
contentReasoning, content := splitQwen35ReasoningContent(content, message.Thinking, isThinking)
|
||||
|
||||
if isThinking && i > lastQueryIndex {
|
||||
sb.WriteString(imStartTag + message.Role + "\n<think>\n" + contentReasoning + "\n</think>\n\n" + content)
|
||||
} else {
|
||||
sb.WriteString(imStartTag + message.Role + "\n" + content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for j, toolCall := range message.ToolCalls {
|
||||
if j == 0 {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_call>\n<function=" + toolCall.Function.Name + ">\n")
|
||||
for name, value := range toolCall.Function.Arguments.All() {
|
||||
sb.WriteString("<parameter=" + name + ">\n")
|
||||
sb.WriteString(formatToolCallArgument(value))
|
||||
sb.WriteString("\n</parameter>\n")
|
||||
}
|
||||
sb.WriteString("</function>\n</tool_call>")
|
||||
}
|
||||
}
|
||||
|
||||
if !prefill {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
} else if message.Role == "tool" {
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
|
||||
if i == len(messages)-1 || messages[i+1].Role != "tool" {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// prefill at the end
|
||||
if lastMessage && !prefill {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
if isThinking {
|
||||
sb.WriteString("<think>\n")
|
||||
} else if r.emitEmptyThinkOnNoThink {
|
||||
sb.WriteString("<think>\n\n</think>\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,389 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen35RendererUsesXMLToolCallingFormat(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{Role: "user", Content: "Thanks"},
|
||||
}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, tools, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(got, "<tools>") {
|
||||
t.Fatalf("expected tools section in prompt, got:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, "<function=example_function_name>") {
|
||||
t.Fatalf("expected xml-style tool call instructions, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolCall := "<tool_call>\n<function=get_weather>\n<parameter=location>\nParis\n</parameter>\n</function>\n</tool_call>"
|
||||
if !strings.Contains(got, wantToolCall) {
|
||||
t.Fatalf("expected xml tool call payload, got:\n%s", got)
|
||||
}
|
||||
|
||||
toolsIdx := strings.Index(got, "# Tools")
|
||||
systemIdx := strings.Index(got, "You are a helpful assistant.")
|
||||
if toolsIdx == -1 || systemIdx == -1 || systemIdx < toolsIdx {
|
||||
t.Fatalf("expected system prompt appended after tool instructions, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererNoThinkPrefill(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, nil, &api.ThinkValue{Value: false})
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n\n</think>\n\n") {
|
||||
t.Fatalf("expected explicit no-think prefill, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererBackToBackToolCallsAndResponses(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Run add and multiply."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll run both now.",
|
||||
Thinking: "Need to call add and multiply.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "a", Value: 2},
|
||||
{Key: "b", Value: 3},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "multiply",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "x", Value: 4},
|
||||
{Key: "y", Value: 5},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "5"},
|
||||
{Role: "tool", Content: "20"},
|
||||
{Role: "user", Content: "Summarize the results."},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, qwen35MathTools(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
if strings.Contains(got, "Need to call add and multiply.") {
|
||||
t.Fatalf("did not expect historical reasoning block in this sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolCalls := `<tool_call>
|
||||
<function=add>
|
||||
<parameter=a>
|
||||
2
|
||||
</parameter>
|
||||
<parameter=b>
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=multiply>
|
||||
<parameter=x>
|
||||
4
|
||||
</parameter>
|
||||
<parameter=y>
|
||||
5
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>`
|
||||
if !strings.Contains(got, wantToolCalls) {
|
||||
t.Fatalf("expected back-to-back tool calls, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantToolResponses := `<|im_start|>user
|
||||
<tool_response>
|
||||
5
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
20
|
||||
</tool_response><|im_end|>`
|
||||
if !strings.Contains(got, wantToolResponses) {
|
||||
t.Fatalf("expected grouped back-to-back tool responses, got:\n%s", got)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererInterleavedThinkingAndTools(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
|
||||
msgs := []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Plan a picnic in Paris."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking weather first.",
|
||||
Thinking: "Need weather before giving advice.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "22C"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Checking UV too.",
|
||||
Thinking: "Need UV index for sunscreen advice.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_uv",
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{Key: "location", Value: "Paris"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "5"},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, qwen35WeatherUVTools(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
wantFirstTurn := `<|im_start|>assistant
|
||||
<think>
|
||||
Need weather before giving advice.
|
||||
</think>
|
||||
|
||||
Checking weather first.
|
||||
|
||||
<tool_call>
|
||||
<function=get_weather>
|
||||
<parameter=location>
|
||||
Paris
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>`
|
||||
if !strings.Contains(got, wantFirstTurn) {
|
||||
t.Fatalf("expected first assistant thinking/tool sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
wantSecondTurn := `<|im_start|>assistant
|
||||
<think>
|
||||
Need UV index for sunscreen advice.
|
||||
</think>
|
||||
|
||||
Checking UV too.
|
||||
|
||||
<tool_call>
|
||||
<function=get_uv>
|
||||
<parameter=location>
|
||||
Paris
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>`
|
||||
if !strings.Contains(got, wantSecondTurn) {
|
||||
t.Fatalf("expected second assistant thinking/tool sequence, got:\n%s", got)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(got, "<|im_start|>assistant\n<think>\n") {
|
||||
t.Fatalf("expected assistant thinking prefill at end, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35RendererAssistantPrefillWithThinking(t *testing.T) {
|
||||
renderer := &Qwen35Renderer{isThinking: true}
|
||||
msgs := []api.Message{
|
||||
{Role: "user", Content: "Write two words."},
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "Keep it short.",
|
||||
Content: "Hello world",
|
||||
},
|
||||
}
|
||||
|
||||
got, err := renderer.Render(msgs, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("render failed: %v", err)
|
||||
}
|
||||
|
||||
want := `<|im_start|>user
|
||||
Write two words.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
Keep it short.
|
||||
</think>
|
||||
|
||||
Hello world`
|
||||
if got != want {
|
||||
t.Fatalf("unexpected prefill output\n--- got ---\n%s\n--- want ---\n%s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func qwen35MathTools() []api.Tool {
|
||||
return []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "a",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "b",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "multiply",
|
||||
Description: "Multiply two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "x",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "y",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"integer"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"x", "y"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func qwen35WeatherUVTools() []api.Tool {
|
||||
return []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_uv",
|
||||
Description: "Get UV index for a location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{
|
||||
Key: "location",
|
||||
Value: api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func rendererForName(name string) Renderer {
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "qwen3.5":
|
||||
renderer := &Qwen35Renderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||
renderer := &Qwen3VLRenderer{isThinking: true, emitEmptyThinkOnNoThink: true, useImgTags: RenderImgTags}
|
||||
return renderer
|
||||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
|
||||
case "glm-4.7":
|
||||
return &GLM47Renderer{}
|
||||
case "glm-ocr":
|
||||
return &GlmOcrRenderer{useImgTags: RenderImgTags}
|
||||
return &GlmOcrRenderer{}
|
||||
case "lfm2":
|
||||
return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
|
||||
case "lfm2-thinking":
|
||||
|
||||
@@ -562,7 +562,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
seq.sampler.Reset()
|
||||
// Skip this sequence but continue processing the rest
|
||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||
err = nil
|
||||
@@ -693,12 +692,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// (unless we take down the whole runner).
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
for _, inp := range seq.pendingInputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
seq.pendingInputs = []*input.Input{}
|
||||
}
|
||||
|
||||
@@ -899,9 +892,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.RepeatPenalty,
|
||||
req.Options.PresencePenalty,
|
||||
req.Options.FrequencyPenalty,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
@@ -948,14 +938,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
seq.sampler.Reset()
|
||||
for _, inp := range seq.cache.Inputs {
|
||||
if len(inp.Multimodal) != 0 {
|
||||
continue
|
||||
}
|
||||
seq.sampler.Accept(inp.Token)
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
|
||||
@@ -16,49 +16,24 @@ type token struct {
|
||||
value float32 // The raw logit or probability from the model
|
||||
}
|
||||
|
||||
const DefaultPenaltyLookback = 64
|
||||
|
||||
type Sampler struct {
|
||||
rng *rand.Rand
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
temperature float32
|
||||
repeat float32
|
||||
presence float32
|
||||
frequency float32
|
||||
history []int32
|
||||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) Reset() {
|
||||
s.history = s.history[:0]
|
||||
}
|
||||
|
||||
func (s *Sampler) Accept(token int32) {
|
||||
s.history = append(s.history, token)
|
||||
if len(s.history) > DefaultPenaltyLookback {
|
||||
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
|
||||
s.history = s.history[:DefaultPenaltyLookback]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
counts := tokenCounts(s.history, len(logits))
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = value
|
||||
tokens[i].value = logits[i]
|
||||
}
|
||||
|
||||
t, err := s.sample(tokens)
|
||||
@@ -80,12 +55,8 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
// we need to reset them before applying the grammar and
|
||||
// sampling again
|
||||
for i := range logits {
|
||||
value := logits[i]
|
||||
if count := counts[int32(i)]; count > 0 {
|
||||
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
|
||||
}
|
||||
tokens[i].id = int32(i)
|
||||
tokens[i].value = value
|
||||
tokens[i].value = logits[i]
|
||||
}
|
||||
s.grammar.Apply(tokens)
|
||||
t, err = s.sample(tokens)
|
||||
@@ -156,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
|
||||
var rng *rand.Rand
|
||||
if seed != -1 {
|
||||
// PCG requires two parameters: sequence and stream
|
||||
@@ -183,19 +154,12 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, repea
|
||||
minP = 1.0
|
||||
}
|
||||
|
||||
if repeatPenalty <= 0 {
|
||||
repeatPenalty = 1.0
|
||||
}
|
||||
|
||||
return Sampler{
|
||||
rng: rng,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
temperature: temperature,
|
||||
repeat: repeatPenalty,
|
||||
presence: presencePenalty,
|
||||
frequency: frequencyPenalty,
|
||||
grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
|
||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
sampler.Sample(logits)
|
||||
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
for _, tc := range configs {
|
||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
|
||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
||||
sampler.Sample(logits)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
|
||||
|
||||
// Test with combined transforms separately - topK influences performance greatly
|
||||
b.Run("TransformCombined", func(b *testing.B) {
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
|
||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
|
||||
logits[i] = float32(rand.Float64()*10 - 5)
|
||||
}
|
||||
|
||||
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
|
||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
logits := []float32{-10, 3, -10, -10}
|
||||
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
||||
got, err := sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{-100, -10, 0, 10}
|
||||
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
|
||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
|
||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
||||
}
|
||||
|
||||
// Generate random logits for benchmarking
|
||||
|
||||
@@ -25,48 +25,6 @@ func (h *tokenHeap) Pop() any {
|
||||
return x
|
||||
}
|
||||
|
||||
func tokenCounts(history []int32, vocabSize int) map[int32]int {
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := 0
|
||||
if len(history) > DefaultPenaltyLookback {
|
||||
start = len(history) - DefaultPenaltyLookback
|
||||
}
|
||||
|
||||
counts := make(map[int32]int, len(history)-start)
|
||||
for _, token := range history[start:] {
|
||||
if token < 0 || int(token) >= vocabSize {
|
||||
continue
|
||||
}
|
||||
counts[token]++
|
||||
}
|
||||
|
||||
return counts
|
||||
}
|
||||
|
||||
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
|
||||
if repeatPenalty != 1.0 {
|
||||
// Preserve ordering for negative logits when applying repeat penalty.
|
||||
if logit < 0 {
|
||||
logit *= repeatPenalty
|
||||
} else {
|
||||
logit /= repeatPenalty
|
||||
}
|
||||
}
|
||||
|
||||
if frequencyPenalty != 0 {
|
||||
logit -= float32(count) * frequencyPenalty
|
||||
}
|
||||
|
||||
if presencePenalty != 0 {
|
||||
logit -= presencePenalty
|
||||
}
|
||||
|
||||
return logit
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
|
||||
@@ -295,86 +295,6 @@ func TestMinP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCounts(t *testing.T) {
|
||||
history := make([]int32, 70)
|
||||
history[0] = 7
|
||||
history[69] = 7
|
||||
|
||||
counts := tokenCounts(history, 8)
|
||||
if got := counts[7]; got != 1 {
|
||||
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPenalty(t *testing.T) {
|
||||
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-2.0)) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
|
||||
}
|
||||
|
||||
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
|
||||
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
|
||||
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerPresencePenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 0.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
|
||||
presence.Accept(1)
|
||||
got, err = presence.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got == 1 {
|
||||
t.Fatalf("presence penalty did not change repeated token selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplerFrequencyPenalty(t *testing.T) {
|
||||
logits := []float32{0.0, 5.0, 4.0}
|
||||
|
||||
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
baseline.Accept(1)
|
||||
got, err := baseline.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
|
||||
}
|
||||
|
||||
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
frequency.Accept(1)
|
||||
got, err = frequency.Sample(logits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != 2 {
|
||||
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
// Generate random logits
|
||||
tokens := make([]token, 1<<16)
|
||||
|
||||
@@ -65,22 +65,11 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
|
||||
for v, digest := range r.Files {
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
return
|
||||
}
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, digest := range r.Adapters {
|
||||
if digest == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
|
||||
@@ -71,10 +71,6 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) IsMLX() bool {
|
||||
return m.Config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
@@ -30,44 +30,42 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -367,33 +366,3 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
|
||||
t.Fatal("prompt is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "extract text",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: "glm-ocr"},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
// Deprecated runner override option; ignore if present.
|
||||
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
|
||||
delete(requestOpts, "use_imagegen_runner")
|
||||
|
||||
opts, err := s.modelOptions(model, requestOpts)
|
||||
@@ -158,7 +158,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
@@ -484,8 +484,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -2217,9 +2216,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
if m.IsMLX() {
|
||||
truncate = false
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
|
||||
@@ -144,37 +144,6 @@ func TestCreateFromBin(t *testing.T) {
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
})
|
||||
|
||||
t.Run("empty file digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty adapter digest", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "my-gguf-model",
|
||||
Files: map[string]string{"0.gguf": digest},
|
||||
Adapters: map[string]string{"adapter.gguf": ""},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "invalid digest format") {
|
||||
t.Errorf("expected invalid digest format error, got:\n%s", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateFromModel(t *testing.T) {
|
||||
|
||||
@@ -33,6 +33,7 @@ type LlmRequest struct {
|
||||
successCh chan *runnerRef
|
||||
errCh chan error
|
||||
schedAttempts uint
|
||||
useImagegen bool
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
@@ -105,7 +106,7 @@ func schedulerModelKey(m *Model) string {
|
||||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
@@ -122,6 +123,7 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
||||
sessionDuration: sessionDuration,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
useImagegen: useImagegen,
|
||||
}
|
||||
|
||||
key := schedulerModelKey(req.model)
|
||||
@@ -229,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.IsMLX() {
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
@@ -591,15 +593,20 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
// loadMLX loads an experimental safetensors model using MLX runners.
|
||||
// Image models use x/imagegen; LLM models use x/mlxrunner.
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
modelName := req.model.ShortName
|
||||
var server llm.LlamaServer
|
||||
var err error
|
||||
|
||||
isImagegen := false
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
server, err = imagegen.NewServer(modelName)
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
|
||||
isImagegen = true
|
||||
} else if req.useImagegen {
|
||||
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
|
||||
isImagegen = true
|
||||
} else {
|
||||
server, err = mlxrunner.NewClient(modelName)
|
||||
}
|
||||
@@ -621,7 +628,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
|
||||
isImagegen: isImagegen,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
@@ -730,8 +737,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested.
|
||||
wantImagegen := slices.Contains(req.model.Config.Capabilities, "image")
|
||||
// Check if runner type (imagegen vs mlxrunner) matches what's requested
|
||||
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
|
||||
if runner.isImagegen != wantImagegen {
|
||||
return true
|
||||
}
|
||||
@@ -757,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
slog.Info("b")
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
require.Empty(t, successCh1b)
|
||||
require.Len(t, errCh1b, 1)
|
||||
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
|
||||
|
||||
c.req.model.ModelPath = "bad path"
|
||||
slog.Info("c")
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
|
||||
// Starts in pending channel, then should be quickly processed to return an error
|
||||
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||
require.Empty(t, successCh1c)
|
||||
@@ -470,7 +470,7 @@ func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil)
|
||||
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||
|
||||
require.Empty(t, successCh)
|
||||
require.Empty(t, errCh)
|
||||
@@ -499,7 +499,7 @@ func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil)
|
||||
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||
cancelReq()
|
||||
|
||||
select {
|
||||
@@ -574,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
s.newServerFn = scenario1a.newServer
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
s.Run(ctx)
|
||||
select {
|
||||
|
||||
@@ -288,18 +288,6 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func isStackedExpertWeight(name string) bool {
|
||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||
// or "...proj" (pre-stacked packed tensor).
|
||||
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||
strings.Contains(name, ".mlp.experts.") ||
|
||||
strings.Contains(name, ".mlp.shared_experts.")
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
@@ -308,25 +296,18 @@ func isStackedExpertWeight(name string) bool {
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
// Use basic name-based check first
|
||||
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||
if !ShouldQuantize(name, "") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||
// e.g. qwen switch_mlp / experts combined tensors.
|
||||
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
var elems int64 = 1
|
||||
for _, d := range shape {
|
||||
elems *= int64(d)
|
||||
}
|
||||
if elems < 1024 {
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -557,10 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
@@ -623,44 +619,6 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
gateUp := GetTensorQuantization(
|
||||
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||
[]int32{64, 22016, 4096},
|
||||
"int4",
|
||||
)
|
||||
if gateUp != "int4" {
|
||||
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||
}
|
||||
|
||||
down := GetTensorQuantization(
|
||||
"model.layers.1.mlp.experts.down_proj.weight",
|
||||
[]int32{64, 4096, 14336},
|
||||
"int4",
|
||||
)
|
||||
if down != "int8" {
|
||||
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||
}
|
||||
|
||||
combinedGateUp := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
[]int32{256, 1024, 2048},
|
||||
"int8",
|
||||
)
|
||||
if combinedGateUp != "int8" {
|
||||
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||
}
|
||||
|
||||
combinedDown := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||
[]int32{256, 2048, 512},
|
||||
"int4",
|
||||
)
|
||||
if combinedDown != "int8" {
|
||||
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
@@ -10,7 +10,17 @@ go build -tags mlx -o engine ./x/imagegen/cmd/engine
|
||||
|
||||
## Text Generation
|
||||
|
||||
Text generation models are no longer supported by this engine.
|
||||
```bash
|
||||
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `-temperature` - sampling temperature (default 0.7)
|
||||
- `-top-p` - nucleus sampling (default 0.9)
|
||||
- `-top-k` - top-k sampling (default 40)
|
||||
|
||||
Supports: Llama, Gemma3, GPT-OSS
|
||||
|
||||
## Image Generation
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
@@ -167,11 +170,11 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it.
|
||||
// Load image if provided and model supports it
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
@@ -233,8 +236,14 @@ func load(modelPath string) (Model, error) {
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "gpt_oss":
|
||||
return gpt_oss.Load(modelPath)
|
||||
case "gemma3":
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
default:
|
||||
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||
return llama.Load(modelPath)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
420
x/imagegen/llm.go
Normal file
420
x/imagegen/llm.go
Normal file
@@ -0,0 +1,420 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextModel is the interface for LLM text generation models.
|
||||
type TextModel interface {
|
||||
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
MaxContextLength() int32
|
||||
NumLayers() int
|
||||
}
|
||||
|
||||
// llmState holds the state for LLM generation
|
||||
type llmState struct {
|
||||
model TextModel
|
||||
}
|
||||
|
||||
var llmMu sync.Mutex
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
// This matches the pattern from cmd/engine/generate.go
|
||||
type Decoder struct {
|
||||
model TextModel
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
token *mlx.Array // Current token (kept across iterations)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
}
|
||||
|
||||
func NewDecoder(m TextModel, temp float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
// prefill feeds the prompt tokens through the model to populate the KV
// caches, leaving d.token holding the first sampled token (queued via
// AsyncEval but not yet synced). It returns the number of prompt tokens
// processed.
//
// All but the last token are processed in chunks of up to 2048 so the cache
// state can be evaluated and the previous state freed between chunks,
// bounding peak memory. The final chunk runs forward and samples the first
// generated token.
func (d *Decoder) prefill(inputIDs []int32) int {
	processed := 0

	// Track old cache state to free after each chunk
	var oldCacheState []*mlx.Array

	// Process all-but-1 tokens in chunks, eval cache state for memory management
	for len(inputIDs) > 1 {
		chunkSize := min(2048, len(inputIDs)-1)
		if chunkSize <= 0 {
			break
		}
		chunk := inputIDs[:chunkSize]

		// Save old cache state before forward
		oldCacheState = oldCacheState[:0]
		for _, c := range d.caches {
			oldCacheState = append(oldCacheState, c.State()...)
		}

		var cacheState []*mlx.Array
		withStream(func() {
			x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
			// Logits are discarded here; only the cache update matters.
			d.model.Forward(x, d.caches)
			for _, c := range d.caches {
				cacheState = append(cacheState, c.State()...)
			}
		})
		mlx.Eval(cacheState...)

		// Free old cache state
		for _, arr := range oldCacheState {
			if arr != nil {
				arr.Free()
			}
		}

		inputIDs = inputIDs[chunkSize:]
		processed += chunkSize
	}

	// Save old cache state before final step
	oldCacheState = oldCacheState[:0]
	for _, c := range d.caches {
		oldCacheState = append(oldCacheState, c.State()...)
	}

	// Final token + sampling
	withStream(func() {
		x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
		mlx.Eval(x) // Materialize before any other evals
		logits := d.model.Forward(x, d.caches)
		d.token = sample(logits, d.temp, d.vocabSize)
	})
	// Keep cache state (token auto-kept by AsyncEval)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Free old cache state from before final step
	for _, arr := range oldCacheState {
		if arr != nil {
			arr.Free()
		}
	}

	mlx.ClearCache()

	return processed + len(inputIDs)
}
|
||||
|
||||
// step advances generation by one token: it runs a forward pass on the
// previously sampled token, queues sampling of the next token
// asynchronously, and returns the previous token's value.
//
// Ordering matters: the next step's compute is queued with AsyncEval
// BEFORE syncing on the previous token, so the GPU keeps working while the
// CPU reads the result. The old token and cache state are freed only after
// that sync.
func (d *Decoder) step() int32 {
	prevToken := d.token

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
		d.oldCacheState = append(d.oldCacheState, c.State()...)
	}

	withStream(func() {
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
		d.token = sample(logits, d.temp, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
	mlx.Keep(d.token)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Sync on previous token (GPU already working on next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
		// NOTE(review): unlike prefill, entries are not nil-checked before
		// Free — presumably State() returns no nil slots here; confirm.
		arr.Free()
	}
	return val
}
|
||||
|
||||
// sample samples from logits using temperature scaling
|
||||
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocabSize)
|
||||
|
||||
if temp <= 0 || temp < 0.01 {
|
||||
// Greedy decoding
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
|
||||
// Apply temperature scaling
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
|
||||
// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
|
||||
func (s *server) loadLLMModel() error {
|
||||
// Load the manifest to get model information
|
||||
modelManifest, err := manifest.LoadManifest(s.modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Detect model architecture from config.json
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &modelConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
arch := ""
|
||||
if len(modelConfig.Architectures) > 0 {
|
||||
arch = modelConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = modelConfig.ModelType
|
||||
}
|
||||
|
||||
slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
|
||||
|
||||
// Load the appropriate model based on architecture
|
||||
var model TextModel
|
||||
archLower := strings.ToLower(arch)
|
||||
|
||||
switch {
|
||||
case strings.Contains(archLower, "glm4moelite"):
|
||||
m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
|
||||
}
|
||||
model = m
|
||||
slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
|
||||
|
||||
default:
|
||||
return fmt.Errorf("LLM architecture %q is not yet supported. "+
|
||||
"Supported architectures: glm4-moe-lite. "+
|
||||
"Please convert your model to GGUF format or use a supported architecture", arch)
|
||||
}
|
||||
|
||||
s.llmModel = &llmState{
|
||||
model: model,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLLMCompletion handles LLM text generation requests.
|
||||
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
|
||||
if s.llmModel == nil {
|
||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize generation requests
|
||||
llmMu.Lock()
|
||||
defer llmMu.Unlock()
|
||||
|
||||
if err := s.llmGenerate(w, r, req); err != nil {
|
||||
slog.Error("LLM generation failed", "error", err)
|
||||
// Don't send error if we've already started streaming
|
||||
}
|
||||
}
|
||||
|
||||
// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine,
// streaming newline-delimited JSON Response objects to w. It returns an error
// only for failures that occur before streaming starts; mid-stream encode
// errors are deliberately not surfaced to the client.
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
	state := s.llmModel

	// Set up streaming response
	w.Header().Set("Content-Type", "application/x-ndjson")
	w.Header().Set("Transfer-Encoding", "chunked")
	flusher, ok := w.(http.Flusher)
	if !ok {
		return errors.New("streaming not supported")
	}

	tok := state.model.Tokenizer()

	// The prompt is already formatted by the server using the model's renderer
	// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
	prompt := req.Prompt

	// Tokenize the prompt
	inputIDs := tok.Encode(prompt, true)
	slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))

	// Generation parameters: default budget is the model's context length
	// (4096 when unknown), overridden by an explicit NumPredict.
	maxTokens := int(state.model.MaxContextLength())
	if maxTokens <= 0 {
		maxTokens = 4096
	}
	if req.Options != nil && req.Options.NumPredict > 0 {
		maxTokens = req.Options.NumPredict
	}

	// Temperature defaults to 0.7; non-positive request values keep the default.
	temperature := float32(0.7)
	if req.Options != nil && req.Options.Temperature > 0 {
		temperature = float32(req.Options.Temperature)
	}

	// Enable MLX compilation for better performance
	mlx.EnableCompile()

	// Create decoder with fresh caches
	dec := NewDecoder(state.model, temperature)

	prefillStart := time.Now()
	prefillTokens := dec.prefill(inputIDs)
	// Prefill measurement includes time to first token
	firstToken := dec.step()
	prefillDuration := time.Since(prefillStart)
	promptEvalDuration := prefillDuration

	enc := json.NewEncoder(w)
	ctx := r.Context()
	generated := 0
	stopReason := "max_tokens"

	// Handle first token
	generated++
	if tok.IsEOS(firstToken) {
		resp := Response{
			Done:               true,
			StopReason:         fmt.Sprintf("first_token_eos:%d", firstToken),
			PromptEvalCount:    prefillTokens,
			PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
		}
		enc.Encode(resp)
		flusher.Flush()
		return nil
	}

	text := tok.Decode([]int32{firstToken})
	resp := Response{Content: text}
	enc.Encode(resp)
	flusher.Flush()

	genStart := time.Now()

	// Generation loop (starts at 1: the first token was produced above)
	for n := 1; n < maxTokens; n++ {
		// Check for cancellation. NOTE: this break only exits the select;
		// the stopReason check just below is what exits the loop.
		select {
		case <-ctx.Done():
			stopReason = fmt.Sprintf("context_cancelled:%d", generated)
			break
		default:
		}
		if stopReason != "max_tokens" {
			break
		}

		token := dec.step()
		generated++

		if tok.IsEOS(token) {
			stopReason = fmt.Sprintf("eos_token:%d", token)
			break
		}

		text := tok.Decode([]int32{token})

		// Check for stop sequences; emit only the text preceding the match.
		if req.Options != nil && len(req.Options.Stop) > 0 {
			shouldStop := false
			var matchedStop string
			for _, stop := range req.Options.Stop {
				if strings.Contains(text, stop) {
					text = strings.Split(text, stop)[0]
					shouldStop = true
					matchedStop = stop
					break
				}
			}
			if shouldStop {
				if text != "" {
					resp := Response{Content: text}
					enc.Encode(resp)
					flusher.Flush()
				}
				stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
				break
			}
		}

		resp := Response{Content: text}
		enc.Encode(resp)
		flusher.Flush()

		// Periodically clear MLX cache
		if n%256 == 0 {
			mlx.ClearCache()
		}
	}

	// Clean up
	mlx.ClearCache()

	// Send final response with stats
	evalDuration := time.Since(genStart)
	resp = Response{
		Done:               true,
		StopReason:         fmt.Sprintf("%s:generated=%d", stopReason, generated),
		PromptEvalCount:    prefillTokens,
		PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
		EvalCount:          generated,
		EvalDuration:       int(evalDuration.Nanoseconds()),
	}
	enc.Encode(resp)
	flusher.Flush()

	return nil
}
|
||||
614
x/imagegen/models/gemma3/gemma3.go
Normal file
614
x/imagegen/models/gemma3/gemma3.go
Normal file
@@ -0,0 +1,614 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sync"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)
|
||||
|
||||
// TextConfig holds configuration for the text model, decoded from
// config.json (or its text_config sub-object for multimodal checkpoints).
// Zero-valued fields are backfilled with defaults at load time.
type TextConfig struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	HeadDim               int32   `json:"head_dim"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`             // RoPE base for full-attention layers
	RopeLocalBaseFreq     float32 `json:"rope_local_base_freq"`   // RoPE base for sliding-window layers
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	SlidingWindow         int32   `json:"sliding_window"`
	SlidingWindowPattern  int32   `json:"sliding_window_pattern"` // every Nth layer is full attention (see isLayerSliding)

	// Computed fields
	Scale float32 `json:"-"` // attention scale, 1/sqrt(head_dim); set by the loaders
}
|
||||
|
||||
// TextModel is the Gemma 3 text-only model
type TextModel struct {
	EmbedTokens *nn.Embedding   `weight:"model.embed_tokens"`
	Layers      []*DecoderLayer `weight:"model.layers"`
	Norm        *nn.RMSNorm     `weight:"model.norm"`
	Output      *nn.Linear      `weight:"-"` // Tied to EmbedTokens, set manually

	// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
	NormScaled *mlx.Array `weight:"-"`

	tok *tokenizer.Tokenizer // loaded from tokenizer.json

	// Embedded config gives direct access to fields like HiddenSize.
	*TextConfig
}
|
||||
|
||||
// DecoderLayer is a single transformer block with sandwich norms:
// pre/post norms around both the attention and feed-forward sub-blocks.
type DecoderLayer struct {
	InputNorm    *nn.RMSNorm `weight:"input_layernorm"`
	Attention    *Attention
	PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
	PreFFNorm    *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
	MLP          *MLP
	PostFFNorm   *nn.RMSNorm `weight:"post_feedforward_layernorm"`

	// Precomputed (1 + weight) for Gemma-style RMSNorm
	InputNormScaled    *mlx.Array `weight:"-"`
	PostAttnNormScaled *mlx.Array `weight:"-"`
	PreFFNormScaled    *mlx.Array `weight:"-"`
	PostFFNormScaled   *mlx.Array `weight:"-"`

	// Whether this layer uses sliding window attention
	IsSliding bool
	LayerIdx  int32 // index in TextModel.Layers; set before weight loading
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization
type Attention struct {
	QProj *nn.Linear  `weight:"self_attn.q_proj"`
	KProj *nn.Linear  `weight:"self_attn.k_proj"`
	VProj *nn.Linear  `weight:"self_attn.v_proj"`
	OProj *nn.Linear  `weight:"self_attn.o_proj"`
	QNorm *nn.RMSNorm `weight:"self_attn.q_norm"` // applied per-head after reshape
	KNorm *nn.RMSNorm `weight:"self_attn.k_norm"` // applied per-head after reshape

	// Precomputed (1 + weight) for Gemma-style RMSNorm
	QNormScaled *mlx.Array `weight:"-"`
	KNormScaled *mlx.Array `weight:"-"`
}
|
||||
|
||||
// MLP is the gated feed-forward network; the activation is the tanh
// approximation of GELU (see MLP.Forward and geluApproxImpl).
type MLP struct {
	GateProj *nn.Linear `weight:"mlp.gate_proj"`
	UpProj   *nn.Linear `weight:"mlp.up_proj"`
	DownProj *nn.Linear `weight:"mlp.down_proj"`
}
|
||||
|
||||
// LoadText loads the text-only Gemma 3 model from modelPath, which must
// contain config.json, tokenizer.json, and safetensors weight files.
// Missing config values receive defaults before weights are loaded.
func LoadText(modelPath string) (*TextModel, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	var cfg TextConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Compute scale
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	// Set defaults if not specified
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	if cfg.RopeLocalBaseFreq == 0 {
		cfg.RopeLocalBaseFreq = 10000
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}

	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("load tokenizer: %w", err)
	}

	m := &TextModel{
		Layers:     make([]*DecoderLayer, cfg.NumHiddenLayers),
		TextConfig: &cfg,
		tok:        tok,
	}

	// Initialize layer metadata (index, sliding/global kind) before
	// LoadModule fills in the weights.
	for i := range m.Layers {
		m.Layers[i] = &DecoderLayer{
			LayerIdx:  int32(i),
			IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
		}
	}

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Tied embeddings for output
	m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)

	// Materialize all weights first, then release the weight files.
	mlx.Eval(mlx.Collect(m)...)
	weights.ReleaseAll()

	// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
	precomputeGemmaScaledWeights(m)

	return m, nil
}
|
||||
|
||||
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
|
||||
// This avoids creating temporary arrays on every forward pass
|
||||
func precomputeGemmaScaledWeights(m *TextModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
// Eval all the precomputed weights
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
|
||||
// isLayerSliding reports whether layer layerIdx uses sliding-window
// attention. Pattern N means every Nth layer (1-based) is full attention
// and the rest use the sliding window; a non-positive pattern disables
// sliding entirely.
func isLayerSliding(layerIdx, pattern int32) bool {
	if pattern <= 0 {
		// No sliding window configured.
		return false
	}
	isFullAttention := (layerIdx+1)%pattern == 0
	return !isFullAttention
}
|
||||
|
||||
// Forward runs the text model forward pass
|
||||
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
// Get embeddings and scale by sqrt(hidden_size)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// Forward runs a decoder layer
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
// Pre-attention norm (use precomputed scaled weight)
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Attention
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
|
||||
// Post-attention norm and residual
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-FFN norm
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// MLP
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
|
||||
// Post-FFN norm and residual
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
// Forward runs attention with Gemma 3's per-head Q/K RMS normalization.
//
// x is [B, L, hidden]. Projections are laid out as [B, heads, L, head_dim]
// directly via AsStrided (fusing the usual reshape+transpose), normalized,
// rotated with RoPE (sliding layers use the local base frequency), appended
// to the KV cache, and fed to scaled dot-product attention. A causal mask is
// applied only when L > 1 (prefill); single-token decode needs none.
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape to [B, num_heads, L, head_dim]
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Q/K normalization after reshaping (use precomputed scaled weight)
	q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
	k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)

	// Apply RoPE with appropriate theta: sliding-window layers use the
	// local base frequency, global layers the full rope_theta.
	ropeTheta := cfg.RopeTheta
	if isSliding {
		ropeTheta = cfg.RopeLocalBaseFreq
	}
	q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())

	// Update cache; K/V now include previously cached positions.
	k, v = c.Update(k, v, int(L))

	// Repeat K/V for GQA if needed
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = nn.RepeatKV(k, repeatFactor)
		v = nn.RepeatKV(v, repeatFactor)
	}

	// Attention, then merge heads back to [B, L, heads*head_dim]
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
|
||||
var compiledGeluApprox *mlx.CompiledFunc
|
||||
|
||||
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
|
||||
func getCompiledGeluApprox() *mlx.CompiledFunc {
|
||||
if compiledGeluApprox == nil {
|
||||
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApproxImpl(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGeluApprox
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU approximation (tanh variant)
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApproxImpl(x *mlx.Array) *mlx.Array {
|
||||
// Constants
|
||||
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
|
||||
const coeff = 0.044715
|
||||
|
||||
// x^3
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
// x + 0.044715 * x^3
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
// tanh(...)
|
||||
tanh := mlx.Tanh(scaled)
|
||||
// 1 + tanh(...)
|
||||
onePlusTanh := mlx.AddScalar(tanh, 1.0)
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
|
||||
}
|
||||
|
||||
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
|
||||
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
|
||||
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
|
||||
// Gemma uses (1 + weight) instead of weight
|
||||
scaledWeight := mlx.AddScalar(weight, 1.0)
|
||||
return mlx.RMSNorm(x, scaledWeight, eps)
|
||||
}
|
||||
|
||||
// Interface methods

// NumLayers returns the number of decoder layers.
func (m *TextModel) NumLayers() int { return len(m.Layers) }

// MaxContextLength returns the maximum context length in tokens.
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }

// VocabSize returns the vocabulary size from the text config.
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }

// Tokenizer returns the model's tokenizer (loaded from tokenizer.json).
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
	return m.tok
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template to a prompt
|
||||
func (m *TextModel) FormatPrompt(prompt string) string {
|
||||
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
// Use rotating cache for sliding window layers
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
// Use regular cache for global attention layers
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// Config holds config for the full multimodal model. Zero-valued fields
// are backfilled with defaults in Load.
type Config struct {
	TextConfig   TextConfig   `json:"text_config"`
	VisionConfig VisionConfig `json:"vision_config"`

	// Image token config (from config.json)
	BOITokenIndex    int32 `json:"boi_token_index"`     // <start_of_image> = 255999
	EOITokenIndex    int32 `json:"eoi_token_index"`     // <end_of_image> = 256000
	ImageTokenIndex  int32 `json:"image_token_index"`   // <image_soft_token> = 262144
	MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
}
|
||||
|
||||
// Model is the full Gemma 3 multimodal model: an image encoder, a projector
// into the text embedding space, and the text decoder.
type Model struct {
	VisionTower *VisionTower         `weight:"vision_tower"`
	Projector   *MultiModalProjector `weight:"multi_modal_projector"`
	TextModel   *TextModel           `weight:"language_model"`
	Config      *Config
	tok         *tokenizer.Tokenizer // shared with TextModel.tok (set in Load)
}
|
||||
|
||||
// Load loads the full multimodal Gemma 3 model (vision tower + projector +
// text model) from modelPath.
//
// Multimodal checkpoints often ship an incomplete text_config, so every
// zero-valued field is backfilled with a Gemma 3 4B default before the
// weights are loaded.
func Load(modelPath string) (*Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Set defaults for text config (multimodal config often has incomplete text_config)
	// These defaults match transformers.Gemma3TextConfig defaults
	tc := &cfg.TextConfig
	if tc.HeadDim == 0 {
		tc.HeadDim = 256 // Gemma 3 uses head_dim=256
	}
	if tc.NumAttentionHeads == 0 {
		// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
		tc.NumAttentionHeads = 8
	}
	if tc.NumKeyValueHeads == 0 {
		// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
		tc.NumKeyValueHeads = 4
	}
	if tc.VocabSize == 0 {
		tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
	}
	if tc.RopeTheta == 0 {
		tc.RopeTheta = 1000000
	}
	if tc.RopeLocalBaseFreq == 0 {
		tc.RopeLocalBaseFreq = 10000
	}
	if tc.RMSNormEps == 0 {
		tc.RMSNormEps = 1e-6
	}
	if tc.SlidingWindowPattern == 0 {
		tc.SlidingWindowPattern = 6
	}
	if tc.MaxPositionEmbeddings == 0 {
		tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
	}

	// Compute text model scale
	tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))

	// Set defaults for image token config
	if cfg.BOITokenIndex == 0 {
		cfg.BOITokenIndex = 255999 // <start_of_image>
	}
	if cfg.EOITokenIndex == 0 {
		cfg.EOITokenIndex = 256000 // <end_of_image>
	}
	if cfg.ImageTokenIndex == 0 {
		cfg.ImageTokenIndex = 262144 // <image_soft_token>
	}
	if cfg.MMTokensPerImage == 0 {
		cfg.MMTokensPerImage = 256
	}

	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("load tokenizer: %w", err)
	}

	m := &Model{
		VisionTower: &VisionTower{
			Embeddings: &VisionEmbeddings{},
			Encoder:    make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
			Config:     &cfg.VisionConfig,
		},
		Projector: &MultiModalProjector{},
		TextModel: &TextModel{
			Layers:     make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
			TextConfig: &cfg.TextConfig,
		},
		Config: &cfg,
		tok:    tok,
	}

	// Initialize text layer metadata before LoadModule fills in the weights.
	for i := range m.TextModel.Layers {
		m.TextModel.Layers[i] = &DecoderLayer{
			LayerIdx:  int32(i),
			IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
		}
	}

	// Initialize vision encoder layers
	for i := range m.VisionTower.Encoder {
		m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
	}

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Tied embeddings for text output
	m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
	m.TextModel.tok = tok

	// Materialize all weights first, then release the weight files.
	mlx.Eval(mlx.Collect(m)...)
	weights.ReleaseAll()

	// Precompute (1 + weight) for Gemma-style RMSNorm
	precomputeGemmaScaledWeights(m.TextModel)

	// Precompute projector's scaled weight
	m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
	mlx.Eval(m.Projector.SoftEmbNormScaled)

	return m, nil
}
|
||||
|
||||
// Forward runs the text-only forward pass (no image), delegating to the
// embedded text model.
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	return m.TextModel.Forward(tokens, caches)
}
|
||||
|
||||
// ForwardWithImage runs the multimodal forward pass.
// tokens: [B, L] input token IDs (with image placeholder tokens)
// image: [B, H, W, C] preprocessed image tensor (may be nil for text-only)
//
// Image insertion is only attempted for B == 1; if image is nil or no
// placeholder token is found, this degrades to a plain text forward pass.
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := tokens.Shape()[0], tokens.Shape()[1]
	cfg := m.Config.TextConfig

	// Find image token position FIRST before any eval that might free tokens
	imageStartPos := int32(-1)
	if image != nil && B == 1 {
		tokenData := tokens.DataInt32() // This evals tokens
		for i, t := range tokenData {
			if t == m.Config.ImageTokenIndex {
				imageStartPos = int32(i)
				break
			}
		}
	}

	// Get text embeddings and scale
	h := m.TextModel.EmbedTokens.Forward(tokens)
	h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))

	// Process image if provided
	if image != nil && imageStartPos >= 0 {
		// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
		visionFeatures := m.VisionTower.Forward(image)

		// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
		imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)

		// Eval h and imageEmbeds together so neither gets freed
		mlx.Eval(h, imageEmbeds)

		// Cast imageEmbeds to match text embeddings dtype (bf16)
		if imageEmbeds.Dtype() != h.Dtype() {
			imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
			mlx.Eval(imageEmbeds)
		}

		// Insert image embeddings at the known position; the placeholder
		// tokens are replaced 1:1 so the sequence length is unchanged.
		h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
	}

	// Run through text model layers
	for i, layer := range m.TextModel.Layers {
		h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
	}

	// Final norm and output projection
	return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
}
|
||||
|
||||
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
|
||||
// at a known position (to avoid re-scanning tokens after eval)
|
||||
// textEmbeds: [B, L, hidden_size] text embeddings
|
||||
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
|
||||
// startPos: starting position of image tokens in the sequence
|
||||
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
|
||||
numImageTokens := imageEmbeds.Shape()[1]
|
||||
L := textEmbeds.Shape()[1]
|
||||
|
||||
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
|
||||
afterStart := startPos + numImageTokens
|
||||
|
||||
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
|
||||
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
|
||||
|
||||
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
|
||||
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
|
||||
|
||||
// Concatenate: before + imageEmbeds + after along axis 1
|
||||
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
|
||||
}
|
||||
|
||||
// Interface methods for Model (runner plumbing).

// NumLayers returns the number of transformer layers in the text model.
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }

// MaxContextLength returns the maximum sequence length the model supports.
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }

// VocabSize returns the size of the text vocabulary.
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }

// Tokenizer returns the model's tokenizer.
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }

// NewCache allocates per-layer KV caches sized for maxSeqLen.
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }

// ImageSize returns the expected (square) input image edge length in pixels.
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
|
||||
|
||||
// FormatPrompt wraps a plain user prompt in the Gemma 3 chat template:
// a user turn followed by the opening of the model turn.
func (m *Model) FormatPrompt(prompt string) string {
	return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
|
||||
|
||||
// FormatPromptWithImage wraps the prompt in the Gemma 3 chat template with a
// leading <start_of_image> marker, which ExpandImageTokens later expands into
// the full image placeholder run.
func (m *Model) FormatPromptWithImage(prompt string) string {
	return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
|
||||
|
||||
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
|
||||
// Input tokens containing boi_token (255999) are expanded to:
|
||||
// boi_token + 256 * image_token + eoi_token
|
||||
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
|
||||
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
|
||||
|
||||
for _, t := range tokens {
|
||||
if t == m.Config.BOITokenIndex {
|
||||
// Expand: boi + 256 * image_token + eoi
|
||||
result = append(result, m.Config.BOITokenIndex)
|
||||
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
|
||||
result = append(result, m.Config.ImageTokenIndex)
|
||||
}
|
||||
result = append(result, m.Config.EOITokenIndex)
|
||||
} else {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
package imagegen
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// ProcessImage loads and preprocesses an image for multimodal vision towers.
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP.
|
||||
// ProcessImage loads and preprocesses an image for the vision tower
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
@@ -30,20 +30,20 @@ func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
return ProcessImageData(img, imageSize)
|
||||
}
|
||||
|
||||
// ProcessImageData preprocesses an image.Image for multimodal vision towers.
|
||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
// Resize to target size using bilinear interpolation.
|
||||
// Resize to target size using bilinear interpolation
|
||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert to float32 array [H, W, C] and normalize.
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0.
|
||||
// Convert to float32 array [H, W, C] and normalize
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
||||
data := make([]float32, imageSize*imageSize*3)
|
||||
idx := 0
|
||||
for y := int32(0); y < imageSize; y++ {
|
||||
for x := int32(0); x < imageSize; x++ {
|
||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||
// RGBA returns 16-bit values, convert to 8-bit.
|
||||
// RGBA returns 16-bit values, convert to 8-bit
|
||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||
@@ -51,8 +51,8 @@ func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create MLX array [1, H, W, C] for NHWC layout.
|
||||
// Create MLX array [1, H, W, C] for NHWC layout
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free.
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
||||
return arr, nil
|
||||
}
|
||||
50
x/imagegen/models/gemma3/projector.go
Normal file
50
x/imagegen/models/gemma3/projector.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// MultiModalProjector projects vision-tower features into the text embedding
// space: average pooling, Gemma-style RMS norm, then a linear projection.
type MultiModalProjector struct {
	// mm_input_projection_weight: [vision_hidden, text_hidden]
	InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
	// RMS norm applied to the pooled vision features before projection.
	SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`

	// Precomputed (1 + weight) for Gemma-style RMSNorm. Derived at load
	// time rather than read from the checkpoint (weight tag "-").
	SoftEmbNormScaled *mlx.Array `weight:"-"`
}
|
||||
|
||||
// Forward projects vision features to text space
|
||||
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
||||
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
||||
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
||||
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
||||
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
||||
B := visionFeatures.Shape()[0]
|
||||
visionHidden := visionFeatures.Shape()[2]
|
||||
|
||||
// Reshape to [B, 64, 64, hidden]
|
||||
gridSize := int32(64) // sqrt(4096)
|
||||
pooledSize := int32(16) // 64/4
|
||||
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
||||
|
||||
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
||||
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
||||
|
||||
// Average over pooling dimensions (axes 2 and 4)
|
||||
h = mlx.Mean(h, 4, false)
|
||||
h = mlx.Mean(h, 2, false)
|
||||
|
||||
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
||||
numTokens := pooledSize * pooledSize
|
||||
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
||||
|
||||
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
||||
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
||||
|
||||
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
||||
return mlx.Linear(h, p.InputProjection)
|
||||
}
|
||||
138
x/imagegen/models/gemma3/vision.go
Normal file
138
x/imagegen/models/gemma3/vision.go
Normal file
@@ -0,0 +1,138 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// VisionConfig holds configuration for the SigLIP vision tower.
type VisionConfig struct {
	HiddenSize        int32 `json:"hidden_size"`         // encoder embedding width
	ImageSize         int32 `json:"image_size"`          // square input edge length in pixels
	IntermediateSize  int32 `json:"intermediate_size"`   // MLP hidden width
	NumAttentionHeads int32 `json:"num_attention_heads"` // heads per encoder layer
	NumHiddenLayers   int32 `json:"num_hidden_layers"`   // encoder depth
	PatchSize         int32 `json:"patch_size"`          // conv patch edge length
}

// VisionTower is the SigLIP vision encoder: patch + position embeddings,
// a stack of transformer encoder layers, and a final layer norm.
type VisionTower struct {
	Embeddings    *VisionEmbeddings     `weight:"vision_model.embeddings"`
	Encoder       []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
	PostLayerNorm *nn.LayerNorm         `weight:"vision_model.post_layernorm"`
	Config        *VisionConfig
}

// VisionEmbeddings handles patch and position embeddings.
type VisionEmbeddings struct {
	// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
	// (the transpose happens at forward time, not at load time).
	PatchWeight *mlx.Array    `weight:"patch_embedding.weight"`
	PatchBias   *mlx.Array    `weight:"patch_embedding.bias"`
	PosEmbed    *nn.Embedding `weight:"position_embedding"`
}

// VisionEncoderLayer is a single pre-norm transformer encoder layer.
type VisionEncoderLayer struct {
	LayerNorm1 *nn.LayerNorm    `weight:"layer_norm1"` // norm before attention
	Attention  *VisionAttention `weight:"self_attn"`
	LayerNorm2 *nn.LayerNorm    `weight:"layer_norm2"` // norm before MLP
	MLP        *VisionMLP       `weight:"mlp"`
}

// VisionAttention implements multi-head self-attention (bidirectional).
type VisionAttention struct {
	QProj   *nn.Linear `weight:"q_proj"`
	KProj   *nn.Linear `weight:"k_proj"`
	VProj   *nn.Linear `weight:"v_proj"`
	OutProj *nn.Linear `weight:"out_proj"`
}

// VisionMLP is the two-layer feed-forward network (GELU between FC1 and FC2).
type VisionMLP struct {
	FC1 *nn.Linear `weight:"fc1"`
	FC2 *nn.Linear `weight:"fc2"`
}
|
||||
|
||||
// Forward runs the vision tower on preprocessed images
|
||||
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
|
||||
// Output: [B, num_patches, hidden_size]
|
||||
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
|
||||
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
|
||||
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
|
||||
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
|
||||
|
||||
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
|
||||
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
|
||||
h = mlx.Add(h, bias)
|
||||
|
||||
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
|
||||
B := h.Shape()[0]
|
||||
gridH, gridW := h.Shape()[1], h.Shape()[2]
|
||||
hidden := h.Shape()[3]
|
||||
numPatches := gridH * gridW
|
||||
h = mlx.Reshape(h, B, numPatches, hidden)
|
||||
|
||||
// Add position embeddings
|
||||
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
|
||||
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
|
||||
h = mlx.Add(h, posEmbed)
|
||||
|
||||
// Encoder layers
|
||||
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
|
||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||
for _, layer := range v.Encoder {
|
||||
h = layer.Forward(h, v.Config, scale)
|
||||
}
|
||||
|
||||
// Final layer norm
|
||||
h = v.PostLayerNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward runs a vision encoder layer
|
||||
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
// Pre-norm attention
|
||||
h := l.LayerNorm1.Forward(x)
|
||||
h = l.Attention.Forward(h, cfg, scale)
|
||||
x = mlx.Add(x, h)
|
||||
|
||||
// Pre-norm MLP
|
||||
h = l.LayerNorm2.Forward(x)
|
||||
h = l.MLP.Forward(h)
|
||||
return mlx.Add(x, h)
|
||||
}
|
||||
|
||||
// Forward runs multi-head self-attention
|
||||
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
B, L := x.Shape()[0], x.Shape()[1]
|
||||
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention (no causal mask for vision)
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
|
||||
|
||||
return a.OutProj.Forward(out)
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU activation
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := mlx.GELU(m.FC1.Forward(x))
|
||||
return m.FC2.Forward(h)
|
||||
}
|
||||
840
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
840
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
@@ -0,0 +1,840 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds RoPE scaling configuration (read from config.json).
type RopeScaling struct {
	Factor       float32 `json:"factor"`         // context extension factor
	MscaleAllDim float32 `json:"mscale_all_dim"` // mscale correction input; see computeScale
}
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration, decoded from the model's
// config.json. Fields tagged `json:"-"` are not part of the file: quantization
// fields are set during load, and computed fields are derived from the rest.
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`     // dense MLP width
	MoEIntermediateSize   int32   `json:"moe_intermediate_size"` // per-expert MLP width
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	AttentionBias         bool    `json:"attention_bias"`

	// MLA (Multi-head Latent Attention) parameters
	QLoraRank     int32 `json:"q_lora_rank"`      // low-rank bottleneck for queries
	KVLoraRank    int32 `json:"kv_lora_rank"`     // low-rank latent dim for keys/values
	QKRopeHeadDim int32 `json:"qk_rope_head_dim"` // rotary part of each head
	QKNopeHeadDim int32 `json:"qk_nope_head_dim"` // non-rotary part of each head
	VHeadDim      int32 `json:"v_head_dim"`

	// MoE parameters
	NRoutedExperts      int32   `json:"n_routed_experts"`
	NSharedExperts      int32   `json:"n_shared_experts"`
	NumExpertsPerTok    int32   `json:"num_experts_per_tok"` // topK routed experts per token
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"` // renormalize topK gate scores
	FirstKDenseReplace  int32   `json:"first_k_dense_replace"` // leading layers that stay dense
	NGroup              int32   `json:"n_group"`
	TopKGroup           int32   `json:"topk_group"`

	// RoPE scaling
	RopeScaling *RopeScaling `json:"rope_scaling"`

	// Quantization parameters (set during load based on model quantization)
	QuantGroupSize int    `json:"-"` // Group size for quantization (default 64)
	QuantBits      int    `json:"-"` // Bits per weight (4 or 8)
	QuantMode      string `json:"-"` // Quantization mode ("affine", etc.)

	// Computed fields
	QHeadDim int32   `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
	Scale    float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment; see computeScale
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention with absorption.
// This uses absorbed MLA, which runs attention in the low-rank latent space so
// the KV cache stores the (much smaller) latent representation instead of full
// per-head keys and values.
type MLAAttention struct {
	// Low-rank query projections: q_a_proj -> layernorm -> q_b_proj.
	QAProj      nn.LinearLayer `weight:"self_attn.q_a_proj"`
	QALayerNorm *nn.RMSNorm    `weight:"self_attn.q_a_layernorm"`
	QBProj      nn.LinearLayer `weight:"self_attn.q_b_proj"`

	// Low-rank KV projections (with shared rope component).
	// kv_a_proj_with_mqa emits both the compressed KV latent and the
	// MQA-style rope key shared across heads.
	KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
	KVALayerNorm   *nn.RMSNorm    `weight:"self_attn.kv_a_layernorm"`

	// Absorbed MLA projections (derived from kv_b_proj at load time;
	// weight tag "-" means they are not read directly from the checkpoint).
	// EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
	// UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
	EmbedQ     *nn.MultiLinear `weight:"-"`
	UnembedOut *nn.MultiLinear `weight:"-"`

	// Output projection back to the model hidden size.
	OProj nn.LinearLayer `weight:"self_attn.o_proj"`
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
// This operates in latent space for reduced KV cache memory: queries are
// embedded into the kv_lora_rank latent, attention runs there, and the result
// is unembedded back to per-head value space.
//
// x is [B, L, hidden]; c may be nil (no caching). Returns [B, L, hidden].
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	// Query path: q_a_proj -> layernorm -> q_b_proj (low-rank bottleneck)
	q := a.QAProj.Forward(x)
	q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
	q = a.QBProj.Forward(q)

	// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
	q = mlx.Transpose(q, 0, 2, 1, 3)

	// Split Q into nope (no positional encoding) and rope parts
	qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
	qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})

	// KV path: get compressed KV and k_pe in one projection
	compressedKV := a.KVAProjWithMQA.Forward(x)

	// Split into compressed_kv and k_pe (shared rope component)
	kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
	kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})

	// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
	kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
	kPE = mlx.Transpose(kPE, 0, 2, 1, 3)

	// Apply layernorm to get the kv latent representation
	kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
	// kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
	kvLatent = mlx.ExpandDims(kvLatent, 1)

	// Apply RoPE to the rope parts, offset by how much is already cached
	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
	kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)

	// ABSORBED MLA: project q_nope to latent space
	// qNope: [B, num_heads, L, qk_nope_head_dim]
	// EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
	// Result: [B, num_heads, L, kv_lora_rank]
	qLatent := a.EmbedQ.Forward(qNope)

	// Keys = concat(kvLatent, kPE)
	// kvLatent: [B, 1, L, kv_lora_rank]
	// kPE:      [B, 1, L, qk_rope_head_dim]
	// keys:     [B, 1, L, kv_lora_rank + qk_rope_head_dim]
	keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)

	// Cache the smaller latent representation.
	// We cache keys (latent + rope) and use empty values since values are
	// derived from keys below.
	cachedL := L
	if c != nil {
		// Create placeholder values with 0 dims for cache (we don't actually use cached values)
		placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
		keys, _ = c.Update(keys, placeholderValues, int(L))
		cachedL = int32(keys.Shape()[2])
	}

	// Values are the first kv_lora_rank dims of keys (slice off rope part)
	values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})

	// Queries = concat(qLatent, qPE)
	// qLatent: [B, num_heads, L, kv_lora_rank]
	// qPE:     [B, num_heads, L, qk_rope_head_dim]
	// queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
	queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)

	// Attention in latent space (causal mask only during prefill, L > 1)
	// queries: [B, num_heads, L, kv_lora_rank + rope_dim]
	// keys:    [B, 1, cachedL, kv_lora_rank + rope_dim]
	// values:  [B, 1, cachedL, kv_lora_rank]
	out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)

	// ABSORBED MLA: unembed from latent space
	// out: [B, num_heads, L, kv_lora_rank]
	// UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
	// Result: [B, num_heads, L, v_head_dim]
	out = a.UnembedOut.Forward(out)

	// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)

	return a.OProj.Forward(out)
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP used by the dense
// (non-MoE) layers.
type DenseMLP struct {
	GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
	UpProj   nn.LinearLayer `weight:"mlp.up_proj"`
	DownProj nn.LinearLayer `weight:"mlp.down_proj"`
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism that routes each token to
// its topK experts.
type MoEGate struct {
	Gate nn.LinearLayer `weight:"mlp.gate"`
	// Optional per-expert bias added to the sigmoid scores before topK
	// selection (selection only; the returned scores are unbiased).
	EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
|
||||
|
||||
// Forward computes expert selection indices and scores.
//
// x is [B, L, hidden]. Returns (inds, scores), both [B, L, topK]:
// inds are the selected expert indices and scores their routing weights
// (taken from the unbiased sigmoid scores, optionally renormalized, then
// scaled by RoutedScalingFactor).
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
	// Compute gate logits through linear layer (handles both quantized and non-quantized)
	gates := g.Gate.Forward(x)

	// Sigmoid scoring
	scores := mlx.Sigmoid(gates)
	origScores := scores

	// Add correction bias if present — it influences WHICH experts are
	// picked, but the returned weights come from origScores.
	if g.EScoreCorrectionBias != nil {
		scores = mlx.Add(scores, g.EScoreCorrectionBias)
	}

	// Group-wise expert selection (simplified for n_group=1)
	// Select top-k experts: argpartition on negated scores puts the topK
	// largest in the first topK slots (order within them unspecified).
	topK := cfg.NumExpertsPerTok
	negScores := mlx.Neg(scores)
	inds := mlx.Argpartition(negScores, int(topK)-1, -1)

	shape := inds.Shape()
	inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})

	// Get (unbiased) scores for the selected experts
	scores = mlx.TakeAlongAxis(origScores, inds, -1)

	// Normalize the topK scores to sum to 1 if configured
	if topK > 1 && cfg.NormTopKProb {
		sumScores := mlx.Sum(scores, -1, true)
		scores = mlx.Div(scores, sumScores)
	}

	// Apply routing scaling factor
	scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)

	return inds, scores
}
|
||||
|
||||
// SwitchMLP implements the MoE expert computation using stacked weights.
// Note: No weight tags — these are populated manually by stacking the
// individual experts' weights at load time.
type SwitchMLP struct {
	// Dequantized weights (used when GatherQMM not available)
	GateWeight *mlx.Array
	UpWeight   *mlx.Array
	DownWeight *mlx.Array

	// Quantized weights (used with GatherQMM for 4/8-bit affine)
	GateWeightQ, GateScales, GateBiases *mlx.Array
	UpWeightQ, UpScales, UpBiases       *mlx.Array
	DownWeightQ, DownScales, DownBiases *mlx.Array

	// Quantization bits per projection (supports mixed precision Q4/Q8)
	GateBits int
	UpBits   int
	DownBits int

	// Quantization group size per projection (detected from tensor shapes)
	GateGroupSize int
	UpGroupSize   int
	DownGroupSize int

	// If true, use GatherQMM with the quantized weight triplets above;
	// otherwise GatherMM with the dequantized weights.
	UseQuantized bool
}
|
||||
|
||||
// Forward applies the switched expert MLP.
//
// x is [B, L, hidden]; indices is [B, L, topK] expert ids from the gate.
// Returns the per-expert outputs [B, L, topK, hidden] (weighting/summing over
// topK is done by the caller).
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	shape := x.Shape()
	B, L := shape[0], shape[1]
	topK := cfg.NumExpertsPerTok

	// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)

	// Flatten for gather_mm: [B*L, 1, 1, D]
	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)

	// Flatten indices: [B, L, topK] -> [B*L, topK]
	idxFlat := mlx.Reshape(indices, B*L, topK)

	// Sort tokens by expert id for efficient gather (only worth it when we
	// have many tokens); invOrder restores the original order afterwards.
	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		// Reorder x based on sorted indices (order/topK maps each of the
		// n expert slots back to its source token row)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array

	if s.UseQuantized {
		// Use GatherQMM for quantized weights (faster, keeps weights quantized).
		// Each projection may have different bits and group sizes
		// (mixed precision: Q4 for gate/up, Q8 for down).
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
			nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
			nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
			nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		// Use GatherMM for dequantized/non-quantized weights
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	// Unsort back into token order if we sorted above
	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
|
||||
|
||||
// SharedExperts implements the always-on shared expert MLP (SwiGLU),
// applied to every token in addition to the routed experts.
type SharedExperts struct {
	GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
	UpProj   nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
	DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
}
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer: gate, routed experts,
// and (optionally) shared experts.
type MoE struct {
	Gate          *MoEGate
	SwitchMLP     *SwitchMLP
	SharedExperts *SharedExperts // may be nil when n_shared_experts is 0
}
|
||||
|
||||
// Forward applies the MoE layer
|
||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
// Get expert indices and scores
|
||||
inds, scores := m.Gate.Forward(x, cfg)
|
||||
|
||||
// Apply routed experts
|
||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
||||
|
||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedExperts != nil {
|
||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
||||
}
|
||||
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// DenseBlock represents a dense transformer block, used for the first
// first_k_dense_replace layers (before the MoE layers begin).
type DenseBlock struct {
	Attention              *MLAAttention
	MLP                    *DenseMLP
	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// MoEBlock represents a MoE transformer block (layers at index
// >= first_k_dense_replace).
type MoEBlock struct {
	Attention              *MLAAttention
	MoE                    *MoE
	InputLayerNorm         *nn.RMSNorm `weight:"input_layernorm"`
	PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MoE with residual
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block is the common interface satisfied by both DenseBlock and MoEBlock,
// letting Model.Layers hold a mixed stack of the two.
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model: token embeddings, a
// mixed dense/MoE transformer stack, final norm, and the LM head.
type Model struct {
	EmbedTokens *nn.Embedding  `weight:"model.embed_tokens"`
	Layers      []Block        `weight:"-"` // Loaded manually due to different block types
	Norm        *nn.RMSNorm    `weight:"model.norm"`
	LMHead      nn.LinearLayer `weight:"lm_head"`

	tok *tokenizer.Tokenizer
	*Config
}
|
||||
|
||||
// computeScale computes the attention scale.
|
||||
// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
|
||||
func computeScale(cfg *Config) float32 {
|
||||
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
|
||||
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
|
||||
scale *= s * s
|
||||
}
|
||||
return scale
|
||||
}
|
||||
|
||||
// supportsGatherQMM reports whether a GatherQMM kernel exists for the given
// quantization mode and bit width. Only 4-bit and 8-bit affine quantization
// are currently supported.
func supportsGatherQMM(mode string, bits int) bool {
	if mode != "affine" {
		return false
	}
	switch bits {
	case 4, 8:
		return true
	}
	return false
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization
// components. When the weight is dequantized (or was never quantized), only
// Weight is set and the remaining fields are zero.
type ExpertWeight struct {
	Weight    *mlx.Array // Quantized weight (if quantized) or dequantized weight
	Scales    *mlx.Array // Quantization scales (nil if not quantized)
	Biases    *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
	Bits      int        // Quantization bits (4 or 8), 0 if not quantized
	GroupSize int        // Quantization group size, 0 if not quantized
}
|
||||
|
||||
// getQuantParams returns quantization parameters from model metadata.
|
||||
// Returns groupSize, bits, and mode for the model's quantization type.
|
||||
func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
|
||||
// Use metadata group_size if available (overrides default)
|
||||
if gs := weights.GroupSize(); gs > 0 {
|
||||
groupSize = gs
|
||||
}
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// loadExpertWeight loads an expert weight.
// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
// Otherwise dequantizes and returns only the weight.
// Returns nil when the tensor is absent; callers skip missing experts.
func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
	// Error deliberately ignored: a nil tensor is the "not present" signal.
	w, _ := weights.GetTensor(path + ".weight")
	if w == nil {
		return nil
	}

	// Check if this is a quantized weight by looking for scales
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		scales, _ := weights.GetTensor(scalePath)
		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		// Get quantization params from metadata
		groupSize, bits, mode := getQuantParams(weights)

		// Update config with group size (for GatherQMM calls).
		// Only the first detected value is recorded — NOTE(review): this assumes
		// a uniform group size across all quantized tensors; confirm for
		// mixed-precision checkpoints.
		if cfg.QuantGroupSize == 0 {
			cfg.QuantGroupSize = groupSize
		}

		// If GatherQMM is supported and requested, return quantized components
		if useQuantized && supportsGatherQMM(mode, bits) {
			return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
		}

		// Otherwise dequantize
		return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
	}

	// Not quantized: return the raw weight as-is.
	return &ExpertWeight{Weight: w}
}
|
||||
|
||||
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
// Returns embed_q and unembed_out weights for per-head projections.
//
// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
// Output:
//   - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
//   - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
//
// Returns (nil, nil) when the layer has no kv_b_proj tensor.
func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w, err := weights.GetTensor(path + ".weight")
	if err != nil || w == nil {
		return nil, nil
	}

	// Check if quantized and dequantize (the absorbed form needs full precision
	// for the reshape/transpose below).
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
		scales, _ := weights.GetTensor(scalePath)
		var qbiases *mlx.Array
		qbiasPath := path + ".weight_qbias"
		if weights.HasTensor(qbiasPath) {
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		groupSize, bits, mode := getQuantParams(weights)
		w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
	}

	// w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
	// Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)

	// Split into wk and wv
	// wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
	// wv: [num_heads, v_head_dim, kv_lora_rank]
	wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})

	// Transform for absorbed MLA:
	// embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
	// This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
	embedQ := mlx.Transpose(wk, 0, 2, 1)

	// unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
	// This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
	unembedOut := wv

	return embedQ, unembedOut
}
|
||||
|
||||
// StackedExpertWeights holds stacked weights for all experts of one projection
// (gate, up, or down), stacked along axis 0 for batched expert dispatch.
type StackedExpertWeights struct {
	Weight    *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
	Scales    *mlx.Array // Stacked scales (nil if not quantized)
	Biases    *mlx.Array // Stacked biases (nil if not quantized)
	Bits      int        // Quantization bits (4 or 8), 0 if not quantized
	GroupSize int        // Quantization group size, 0 if not quantized
}
|
||||
|
||||
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
// Experts whose tensor is missing are skipped; quantization scales/biases are
// stacked only when present on the loaded experts.
func collectAndStackExpertWeights(
	weights safetensors.WeightSource,
	prefix string,
	projName string,
	numExperts int32,
	useQuantized bool,
	cfg *Config,
) *StackedExpertWeights {
	var w, s, b []*mlx.Array
	var bits, groupSize int

	for e := int32(0); e < numExperts; e++ {
		path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
		ew := loadExpertWeight(weights, path, useQuantized, cfg)
		if ew == nil {
			continue
		}
		w = append(w, ew.Weight)
		if ew.Scales != nil {
			s = append(s, ew.Scales)
		}
		if ew.Biases != nil {
			b = append(b, ew.Biases)
		}
		// Bits/GroupSize are taken from expert 0 only.
		// NOTE(review): if expert 0 is missing, bits/groupSize stay 0 even when
		// later experts are quantized — confirm expert 0 always exists.
		if e == 0 {
			bits = ew.Bits
			groupSize = ew.GroupSize
		}
	}

	result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
	if len(w) > 0 {
		// Stack along a new leading experts axis.
		result.Weight = mlx.Stack(w, 0)
		if len(s) > 0 {
			result.Scales = mlx.Stack(s, 0)
		}
		if len(b) > 0 {
			result.Biases = mlx.Stack(b, 0)
		}
	}
	return result
}
|
||||
|
||||
// sanitizeExpertWeights stacks individual expert weights into tensors.
|
||||
// If useQuantized is true and weights support GatherQMM, returns quantized components.
|
||||
// Otherwise returns dequantized weights with nil scales/biases.
|
||||
// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
|
||||
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
|
||||
gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
|
||||
up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
|
||||
down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
|
||||
return gate, up, down
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
// It reads config.json, loads all weights, builds the tokenizer, constructs the
// dense and MoE layers (stacking per-expert tensors), evaluates everything, and
// finally releases the raw weight buffers.
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	// Read config from manifest
	configData, err := modelManifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	// Compute derived fields
	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Load weights from manifest blobs
	weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	if err := weights.Load(0); err != nil {
		return nil, fmt.Errorf("load weight data: %w", err)
	}

	// Set up quantization parameters (only if model is actually quantized)
	// Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
	quantization := weights.Quantization()
	useQuantized := false
	if quantization != "" {
		_, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
		useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	}

	// Load tokenizer from manifest with config files for EOS token detection
	tokData, err := modelManifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	// Build tokenizer config with companion files for EOS/BOS token loading
	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData, // Already loaded above, contains eos_token_id
	}

	// Try to load generation_config.json if available (preferred source for EOS).
	// Errors are ignored on purpose: these companion files are optional.
	if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	// Try to load tokenizer_config.json if available
	if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load embedding, norm, and lm_head
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Load layers manually due to different block types
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d attention: %w", i, err)
		}

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		if i < cfg.FirstKDenseReplace {
			// Dense block: the first k layers use a plain MLP instead of MoE.
			block := &DenseBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d dense: %w", i, err)
			}
			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if err := safetensors.LoadModule(block, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d moe block: %w", i, err)
			}

			// Stack expert weights (pass cfg so group sizes can be detected)
			gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)

			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				// Keep quantized components so GatherQMM kernels can be used.
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				// Dequantized (or never-quantized) plain weights.
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			block.MoE = &MoE{
				Gate:      &MoEGate{},
				SwitchMLP: switchMLP,
			}

			// Load gate weights
			if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
				return nil, fmt.Errorf("layer %d gate: %w", i, err)
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{}
				if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
					return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
				}
			}

			m.Layers[i] = block
		}
	}

	// Materialize all lazily-built arrays, then drop the raw weight buffers.
	mlx.Eval(mlx.Collect(m)...)
	weights.ReleaseAll()

	return m, nil
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
return m.LMHead.Forward(h)
|
||||
}
|
||||
|
||||
// Interface methods

// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }

// MaxContextLength returns the maximum context length
// (MaxPositionEmbeddings is promoted from the embedded *Config).
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }

// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
// NewCache creates a new KV cache for the model
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
||||
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
|
||||
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
||||
// When think is true, the prompt ends with <think> to enable reasoning mode.
|
||||
// When think is false, the prompt ends with </think> to skip reasoning.
|
||||
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
||||
if think {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
||||
}
|
||||
|
||||
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
// The renderer is stateless at construction time.
func (m *Model) NewRenderer() *Renderer {
	return &Renderer{}
}

// NewParser returns a new Parser for extracting thinking and tool calls from output.
// The zero-value Parser starts in the LookingForThinkingOpen state; callers
// configure it via Parser.Init.
func (m *Model) NewParser() *Parser {
	return &Parser{}
}
|
||||
479
x/imagegen/models/glm4_moe_lite/parser.go
Normal file
479
x/imagegen/models/glm4_moe_lite/parser.go
Normal file
@@ -0,0 +1,479 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"log/slog"
	"strconv"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)
|
||||
|
||||
// parserState enumerates the states of the streaming output parser.
type parserState int

const (
	// parserState_LookingForThinkingOpen: waiting for a <think> tag (or content).
	parserState_LookingForThinkingOpen parserState = iota
	// parserState_ThinkingStartedEatingWhitespace: <think> seen, skipping leading whitespace.
	parserState_ThinkingStartedEatingWhitespace
	// parserState_CollectingThinking: accumulating thinking text until </think>.
	parserState_CollectingThinking
	// parserState_ThinkingDoneEatingWhitespace: </think> seen, skipping leading whitespace.
	parserState_ThinkingDoneEatingWhitespace
	// parserState_CollectingContent: accumulating regular content until <tool_call>.
	parserState_CollectingContent
	// parserState_ToolStartedEatingWhitespace: <tool_call> seen, skipping leading whitespace.
	parserState_ToolStartedEatingWhitespace
	// parserState_CollectingToolContent: accumulating raw tool-call text until </tool_call>.
	parserState_CollectingToolContent
)
|
||||
|
||||
// Literal tags emitted by the GLM-4 model to delimit thinking and tool calls.
const (
	thinkingOpenTag  = "<think>"
	thinkingCloseTag = "</think>"
	toolOpenTag      = "<tool_call>"
	toolCloseTag     = "</tool_call>"
)
|
||||
|
||||
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
	state  parserState     // current state machine position
	buffer strings.Builder // accumulated, not-yet-emitted model output
	tools  []api.Tool      // tools from Init, used to type-coerce tool-call arguments
}
|
||||
|
||||
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
	return true
}

// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
	return true
}
|
||||
|
||||
// Init initializes the parser with tools and thinking configuration.
// It returns the tools unchanged. lastMessage is not used by this parser
// (presumably required by a shared parser interface — NOTE(review): confirm).
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	// When thinking is enabled (nil or true), the prompt ends with <think>,
	// so model output starts directly with thinking content (no opening tag).
	// Otherwise the zero state (LookingForThinkingOpen) is kept.
	if thinkValue == nil || thinkValue.Bool() {
		p.state = parserState_CollectingThinking
	}
	return tools
}
|
||||
|
||||
// parserEvent is a marker interface for events produced while scanning model output.
type parserEvent interface {
	isParserEvent()
}

// eventContent carries a chunk of regular response content.
type eventContent struct {
	content string
}

func (eventContent) isParserEvent() {}

// eventRawToolCall carries the raw text found between <tool_call> tags,
// not yet parsed into an api.ToolCall.
type eventRawToolCall struct {
	raw string
}

func (eventRawToolCall) isParserEvent() {}

// eventThinkingContent carries a chunk of thinking (reasoning) content.
type eventThinkingContent struct {
	content string
}

func (eventThinkingContent) isParserEvent() {}
|
||||
|
||||
// Add processes new output text and returns parsed content, thinking, and tool calls.
// It appends s to the internal buffer, drains all parseable events, and routes
// each event into the appropriate output. The done flag is not used by this
// implementation. On a tool-call parse failure it returns the error and drops
// any partial results from this call.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder

	for _, event := range events {
		switch event := event.(type) {
		case eventRawToolCall:
			// Inner err intentionally shadows the named return; it is returned explicitly.
			toolCall, err := parseToolCall(event, p.tools)
			if err != nil {
				slog.Warn("glm-4 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			toolCalls = append(toolCalls, toolCall)
		case eventThinkingContent:
			thinkingSb.WriteString(event.content)
		case eventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
|
||||
|
||||
func (p *Parser) parseEvents() []parserEvent {
|
||||
var all []parserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []parserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
|
||||
// and transitions to the next state. Returns (nil, false) if only whitespace remains
|
||||
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
|
||||
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
if trimmed == "" {
|
||||
return nil, false // Still only whitespace, keep waiting for more input
|
||||
}
|
||||
p.state = nextState
|
||||
p.buffer.WriteString(trimmed)
|
||||
return nil, true // Successfully transitioned
|
||||
}
|
||||
|
||||
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
|
||||
// the content after (optionally trimmed of leading whitespace), and updates the buffer
|
||||
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after
|
||||
}
|
||||
|
||||
// eat advances the state machine one step over the buffered text.
// It returns the events produced and whether the caller should call eat again:
// true after a state transition (there may be more to parse), false when more
// model output is needed. Ambiguous suffixes (partial tags, trailing whitespace
// that might precede a tag) are withheld in the buffer rather than emitted.
func (p *Parser) eat() ([]parserEvent, bool) {
	var events []parserEvent

	switch p.state {
	case parserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, thinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, thinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = parserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = parserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case parserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)

	case parserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, thinkingCloseTag) {
			// Complete </think> seen: emit thinking and move on to content.
			thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, eventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = parserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = parserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)

	case parserState_CollectingContent:
		if strings.Contains(p.buffer.String(), toolOpenTag) {
			// Complete <tool_call> seen: emit content before it and switch modes.
			before, after := p.splitAtTag(toolOpenTag, true)
			if len(before) > 0 {
				events = append(events, eventContent{content: before})
			}
			if after == "" {
				p.state = parserState_ToolStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
			// Partial <tool_call> at the end - withhold it plus preceding whitespace.
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		} else {
			// Plain content - withhold trailing whitespace (might precede a tag).
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)

	case parserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, toolCloseTag) {
			toolContent, _ := p.splitAtTag(toolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm4 tool call closing tag found but no content before it")
			}
			events = append(events, eventRawToolCall{raw: toolContent})
			p.state = parserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed
			// We just wait for the closing tag
			return events, false
		}

	default:
		// All states are handled above; reaching here is a programmer error.
		panic("unreachable")
	}
}
|
||||
|
||||
// overlap returns the length of the overlap between the end of s and the
// start of tag — i.e. the length of the shortest prefix of tag that is also
// a suffix of s, or 0 when none exists.
func overlap(s, tag string) int {
	limit := min(len(tag), len(s))
	for n := 1; n <= limit; n++ {
		if strings.HasSuffix(s, tag[:n]) {
			return n
		}
	}
	return 0
}
|
||||
|
||||
// trailingWhitespaceLen returns the number of bytes of trailing Unicode
// whitespace in s.
func trailingWhitespaceLen(s string) int {
	return len(s) - len(strings.TrimRightFunc(s, unicode.IsSpace))
}
|
||||
|
||||
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing.
// The wire format is: function name as bare text, followed by alternating
// <arg_key>/<arg_value> element pairs.
type ToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}
|
||||
|
||||
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
|
||||
func escapeContent(s string) string {
|
||||
var result strings.Builder
|
||||
inTag := false
|
||||
|
||||
for i := range len(s) {
|
||||
ch := s[i]
|
||||
|
||||
if ch == '<' {
|
||||
// Check if this is a known tag
|
||||
if strings.HasPrefix(s[i:], "<arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_key>") ||
|
||||
strings.HasPrefix(s[i:], "<arg_value>") ||
|
||||
strings.HasPrefix(s[i:], "</arg_value>") {
|
||||
inTag = true
|
||||
}
|
||||
}
|
||||
|
||||
if inTag {
|
||||
result.WriteByte(ch)
|
||||
if ch == '>' {
|
||||
inTag = false
|
||||
}
|
||||
} else {
|
||||
// Escape special characters in text content
|
||||
switch ch {
|
||||
case '&':
|
||||
result.WriteString("&")
|
||||
case '<':
|
||||
result.WriteString("<")
|
||||
case '>':
|
||||
result.WriteString(">")
|
||||
default:
|
||||
result.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// parseToolCall converts raw <tool_call> text into an api.ToolCall.
// The raw format is the function name as bare text followed by
// <arg_key>/<arg_value> pairs; argument values are type-coerced using the
// matching tool's declared parameter types when available.
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	escaped := escapeContent(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed ToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type (empty when the tool or property is unknown,
		// in which case the value stays a string).
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}
|
||||
|
||||
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
|
||||
func parseValue(value string, paramType api.PropertyType) any {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// If no type specified, return as string
|
||||
if len(paramType) == 0 {
|
||||
return value
|
||||
}
|
||||
|
||||
// Try to parse based on specified types
|
||||
for _, t := range paramType {
|
||||
switch t {
|
||||
case "boolean":
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
case "integer":
|
||||
var i int64
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
return i
|
||||
}
|
||||
case "number":
|
||||
var f float64
|
||||
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
|
||||
return f
|
||||
}
|
||||
case "array", "object":
|
||||
// Try to parse as JSON
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(value), &result); err == nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
|
||||
192
x/imagegen/models/glm4_moe_lite/parser_test.go
Normal file
192
x/imagegen/models/glm4_moe_lite/parser_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestParserThinking verifies the parser's split of model output into
// thinking, content, and tool calls, under both thinking-enabled and
// thinking-disabled modes.
func TestParserThinking(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		thinkEnabled  bool
		wantContent   string
		wantThinking  string
		wantToolCalls int
	}{
		{
			name:         "thinking enabled - simple thinking then content",
			input:        "Let me think about this...</think>Here is my answer.",
			thinkEnabled: true,
			wantThinking: "Let me think about this...",
			wantContent:  "Here is my answer.",
		},
		{
			name:         "thinking enabled - only thinking",
			input:        "I need to consider multiple factors...",
			thinkEnabled: true,
			wantThinking: "I need to consider multiple factors...",
			wantContent:  "",
		},
		{
			name:         "thinking disabled - direct content",
			input:        "Here is my direct answer.",
			thinkEnabled: false,
			wantThinking: "",
			wantContent:  "Here is my direct answer.",
		},
		{
			name:          "thinking with tool call",
			input:         "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
			thinkEnabled:  true,
			wantThinking:  "Let me search for that...",
			wantContent:   "I'll use a tool.",
			wantToolCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Parser{}

			var thinkValue *api.ThinkValue
			if tt.thinkEnabled {
				thinkValue = &api.ThinkValue{Value: true}
			} else {
				thinkValue = &api.ThinkValue{Value: false}
			}

			// Define tools for tool call tests
			props := api.NewToolPropertiesMap()
			props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
			tools := []api.Tool{
				{
					Function: api.ToolFunction{
						Name: "search",
						Parameters: api.ToolFunctionParameters{
							Properties: props,
						},
					},
				},
			}

			p.Init(tools, nil, thinkValue)

			// Feed the whole input at once with done=true.
			content, thinking, calls, err := p.Add(tt.input, true)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if thinking != tt.wantThinking {
				t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
			}
			if content != tt.wantContent {
				t.Errorf("content = %q, want %q", content, tt.wantContent)
			}
			if len(calls) != tt.wantToolCalls {
				t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
			}
		})
	}
}
|
||||
|
||||
// TestParserToolCall verifies that a GLM4-format tool call with multiple
// key/value argument pairs is parsed into a single api.ToolCall with the
// expected name and string arguments.
func TestParserToolCall(t *testing.T) {
	p := &Parser{}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: props,
				},
			},
		},
	}

	// Initialize with thinking disabled
	tv := &api.ThinkValue{Value: false}
	p.Init(tools, nil, tv)

	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"

	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	call := calls[0]
	if call.Function.Name != "get_weather" {
		t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
	}

	location, ok := call.Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Errorf("location = %v, want %q", location, "San Francisco")
	}

	unit, ok := call.Function.Arguments.Get("unit")
	if !ok || unit != "celsius" {
		t.Errorf("unit = %v, want %q", unit, "celsius")
	}
}
|
||||
|
||||
// TestOverlap checks the helper that measures how many trailing characters of
// s form a prefix of tag (used for streaming partial-tag detection).
func TestOverlap(t *testing.T) {
	tests := []struct {
		s    string
		tag  string
		want int
	}{
		{"hello<", "</think>", 1},
		{"hello</", "</think>", 2},
		{"hello</t", "</think>", 3},
		{"hello</th", "</think>", 4},
		{"hello</thi", "</think>", 5},
		{"hello</thin", "</think>", 6},
		{"hello</think", "</think>", 7},
		{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
		{"hello", "</think>", 0},
		{"", "</think>", 0},
	}

	for _, tt := range tests {
		t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
			got := overlap(tt.s, tt.tag)
			if got != tt.want {
				t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
			}
		})
	}
}
|
||||
|
||||
// TestTrailingWhitespaceLen checks the helper that counts trailing whitespace
// characters (spaces, tabs, newlines) at the end of a string.
func TestTrailingWhitespaceLen(t *testing.T) {
	tests := []struct {
		s    string
		want int
	}{
		{"hello   ", 3},
		{"hello\n\t ", 3},
		{"hello", 0},
		{"", 0},
		{"   ", 3},
	}

	for _, tt := range tests {
		t.Run(tt.s, func(t *testing.T) {
			got := trailingWhitespaceLen(tt.s)
			if got != tt.want {
				t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
			}
		})
	}
}
|
||||
175
x/imagegen/models/glm4_moe_lite/render.go
Normal file
175
x/imagegen/models/glm4_moe_lite/render.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
//   - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//   - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
//   - enable_thinking=true: outputs <think> to start reasoning
//   - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}

// Render renders messages into the GLM4 chat format.
//
// Output always starts with the "[gMASK]<sop>" prefix and ends with an open
// "<|assistant|>" turn primed with either "<think>" (thinking enabled, the
// default) or "</think>" (thinking disabled). Tool definitions, when present,
// are emitted first as a dedicated <|system|> block.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	// Tool definitions go into a system block before any conversation turns.
	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(formatToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	// Thinking defaults to enabled; only an explicit false disables it.
	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			// Preserved thinking: prior reasoning is replayed verbatim; a
			// bare "</think>" marks a turn that had no reasoning.
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			// Consecutive tool results share a single <|observation|> header.
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	// Open the next assistant turn, primed for (or opting out of) thinking.
	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}
|
||||
|
||||
// renderToolArguments converts tool call arguments to GLM4 XML format.
|
||||
func renderToolArguments(args api.ToolCallFunctionArguments) string {
|
||||
var sb strings.Builder
|
||||
for key, value := range args.All() {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>")
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatToolJSON formats JSON for GLM4 tool definitions by adding a space
// after every structural ':' and ','. Separators that appear inside string
// literals (including after escaped quotes) are left untouched.
func formatToolJSON(raw []byte) string {
	var out strings.Builder
	// Reserve room for the original bytes plus ~10% of inserted spaces.
	out.Grow(len(raw) + len(raw)/10)

	// Tiny state machine: outside a string, inside a string, or immediately
	// after a backslash inside a string.
	const (
		stNormal = iota
		stString
		stEscape
	)
	state := stNormal

	for _, ch := range raw {
		out.WriteByte(ch)

		switch state {
		case stEscape:
			// The escaped character has been written; back to string mode.
			state = stString
		case stString:
			switch ch {
			case '\\':
				state = stEscape
			case '"':
				state = stNormal
			}
		default:
			switch ch {
			case '"':
				state = stString
			case ':', ',':
				out.WriteByte(' ')
			}
		}
	}

	return out.String()
}
|
||||
205
x/imagegen/models/glm4_moe_lite/render_test.go
Normal file
205
x/imagegen/models/glm4_moe_lite/render_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestRendererSimple checks the exact rendered prompt for a single user turn
// with default (enabled) thinking.
func TestRendererSimple(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	// Thinking enabled (default)
	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}
|
||||
|
||||
// TestRendererThinkingDisabled checks that an explicit ThinkValue of false
// primes the assistant turn with "</think>" instead of "<think>".
func TestRendererThinkingDisabled(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	tv := &api.ThinkValue{Value: false}

	result, err := r.Render(messages, nil, tv)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}
|
||||
|
||||
// TestRendererMultiTurn checks that prior assistant thinking is preserved in
// context and that the prompt ends primed for a new thinking turn.
func TestRendererMultiTurn(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What is 2+2?"},
		{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
		{Role: "user", Content: "And 3+3?"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check key parts
	if !strings.Contains(result, "[gMASK]<sop>") {
		t.Error("missing [gMASK]<sop> prefix")
	}
	if !strings.Contains(result, "<|user|>What is 2+2?") {
		t.Error("missing first user message")
	}
	if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
		t.Error("missing assistant message with thinking")
	}
	if !strings.Contains(result, "<|user|>And 3+3?") {
		t.Error("missing second user message")
	}
	if !strings.HasSuffix(result, "<|assistant|><think>") {
		t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
	}
}
|
||||
|
||||
// TestRendererWithSystem checks that a system message is rendered under the
// <|system|> tag.
func TestRendererWithSystem(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
		t.Error("missing system message")
	}
}
|
||||
|
||||
// TestRendererWithTools checks that tool definitions produce a <|system|>
// block containing the "# Tools" header and a <tools>…</tools> section with
// the tool's JSON.
func TestRendererWithTools(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What's the weather?"},
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get the weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
					Required:   []string{"location"},
				},
			},
		},
	}

	result, err := r.Render(messages, tools, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check for tool system prompt
	if !strings.Contains(result, "<|system|>") {
		t.Error("missing system tag for tools")
	}
	if !strings.Contains(result, "# Tools") {
		t.Error("missing tools header")
	}
	if !strings.Contains(result, "<tools>") {
		t.Error("missing tools tag")
	}
	if !strings.Contains(result, "get_weather") {
		t.Error("missing tool name")
	}
	if !strings.Contains(result, "</tools>") {
		t.Error("missing closing tools tag")
	}
}
|
||||
|
||||
// TestRendererWithToolCalls checks rendering of an assistant tool call
// (<tool_call> with arg_key/arg_value pairs) followed by a tool result
// (<|observation|> + <tool_response>).
func TestRendererWithToolCalls(t *testing.T) {
	r := &Renderer{}

	args := api.NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")

	messages := []api.Message{
		{Role: "user", Content: "What's the weather in SF?"},
		{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: args,
					},
				},
			},
		},
		{Role: "tool", Content: "Sunny, 72F"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<tool_call>get_weather") {
		t.Error("missing tool call")
	}
	if !strings.Contains(result, "<arg_key>location</arg_key>") {
		t.Error("missing arg_key")
	}
	if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
		t.Error("missing arg_value")
	}
	if !strings.Contains(result, "</tool_call>") {
		t.Error("missing tool call closing tag")
	}
	if !strings.Contains(result, "<|observation|>") {
		t.Error("missing observation tag")
	}
	if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
		t.Error("missing tool response")
	}
}
|
||||
|
||||
// TestFormatToolJSON checks that formatToolJSON inserts spaces after
// structural colons and commas.
func TestFormatToolJSON(t *testing.T) {
	input := []byte(`{"name":"test","value":123}`)
	result := formatToolJSON(input)

	// Should add spaces after : and ,
	if !strings.Contains(result, ": ") {
		t.Error("should add space after colon")
	}
	if !strings.Contains(result, ", ") {
		t.Error("should add space after comma")
	}
}
|
||||
487
x/imagegen/models/gpt_oss/gpt_oss.go
Normal file
487
x/imagegen/models/gpt_oss/gpt_oss.go
Normal file
@@ -0,0 +1,487 @@
|
||||
//go:build mlx
|
||||
|
||||
package gpt_oss
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sync"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)
|
||||
|
||||
// RopeScaling holds YaRN or other RoPE scaling configuration
// as read from the model's config.json "rope_scaling" object.
type RopeScaling struct {
	RopeType                      string  `json:"rope_type"`
	Factor                        float32 `json:"factor"`
	OriginalMaxPositionEmbeddings int32   `json:"original_max_position_embeddings"`
	BetaFast                      float32 `json:"beta_fast"`
	BetaSlow                      float32 `json:"beta_slow"`
}

// Config holds the GPT-OSS hyperparameters deserialized from config.json.
type Config struct {
	HiddenSize        int32        `json:"hidden_size"`
	NumHiddenLayers   int32        `json:"num_hidden_layers"`
	IntermediateSize  int32        `json:"intermediate_size"`
	NumAttentionHeads int32        `json:"num_attention_heads"`
	NumKeyValueHeads  int32        `json:"num_key_value_heads"`
	VocabSize         int32        `json:"vocab_size"`
	RMSNormEps        float32      `json:"rms_norm_eps"`
	RopeTheta         float32      `json:"rope_theta"`
	HeadDim           int32        `json:"head_dim"`
	SlidingWindow     int32        `json:"sliding_window"`
	NumLocalExperts   int32        `json:"num_local_experts"`
	NumExpertsPerTok  int32        `json:"num_experts_per_tok"`
	// LayerTypes labels each layer "sliding_attention" or "full_attention".
	LayerTypes  []string     `json:"layer_types"`
	SwiGLULimit float32      `json:"swiglu_limit"`
	RopeScaling *RopeScaling `json:"rope_scaling"`
	Scale       float32      `json:"-"` // computed: 1/sqrt(HeadDim)
}

// Attention holds one layer's attention projections, optional attention-sink
// logits, and precomputed YaRN RoPE state (see initYarn).
type Attention struct {
	QProj      *nn.Linear `weight:"self_attn.q_proj"`
	KProj      *nn.Linear `weight:"self_attn.k_proj"`
	VProj      *nn.Linear `weight:"self_attn.v_proj"`
	OProj      *nn.Linear `weight:"self_attn.o_proj"`
	Sinks      *mlx.Array `weight:"self_attn.sinks,optional"`
	YarnFreqs  *mlx.Array // computed
	YarnMscale float32
}
|
||||
|
||||
// swiGLU applies the GPT-OSS custom SwiGLU activation.
|
||||
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
|
||||
// with clipping: gate to [None, limit], up to [-limit, limit]
|
||||
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
|
||||
// Clip gate to [None, limit]
|
||||
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
|
||||
|
||||
// Clip up to [-limit, limit]
|
||||
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
|
||||
|
||||
// glu_scaled = alpha * gate_clipped
|
||||
gluScaled := mlx.MulScalar(gateClipped, alpha)
|
||||
|
||||
// sig = sigmoid(glu_scaled)
|
||||
sig := mlx.Sigmoid(gluScaled)
|
||||
|
||||
// out_glu = gate_clipped * sig
|
||||
outGlu := mlx.Mul(gateClipped, sig)
|
||||
|
||||
// result = out_glu * (up_clipped + 1)
|
||||
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
|
||||
}
|
||||
|
||||
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
const alpha float32 = 1.702
|
||||
const limit float32 = 7.0
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
|
||||
}, true) // shapeless=true so it works for any input size
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies.
// Based on mlx-lm's YarnRoPE implementation.
//
// It returns a [dims/2] array of per-dimension rotation frequencies blended
// between the extrapolated (original) and interpolated (scaled) values via a
// linear ramp, plus the magnitude scale (mscale) to apply to queries.
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
	// yarn_find_correction_dim
	yarnFindCorrectionDim := func(numRotations float64) float64 {
		return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
	}

	// yarn_find_correction_range: the [low, high] dimension band over which
	// the interpolation/extrapolation ramp is applied, clamped to valid dims.
	low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
	high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
	if low < 0 {
		low = 0
	}
	if high > int(dims)-1 {
		high = int(dims) - 1
	}

	// yarn_get_mscale
	yarnGetMscale := func(scale, mscale float64) float64 {
		if scale <= 1 {
			return 1.0
		}
		return 0.1*mscale*math.Log(scale) + 1.0
	}
	mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))

	// Compute frequencies
	// freq_extra = base ** (arange(0, dims, 2) / dims)
	// freq_inter = scaling_factor * freq_extra
	halfDims := dims / 2
	freqData := make([]float32, halfDims)
	for i := int32(0); i < halfDims; i++ {
		exp := float64(2*i) / float64(dims)
		freqExtra := math.Pow(float64(base), exp)
		freqInter := float64(scalingFactor) * freqExtra

		// linear ramp mask: 1 below the band, 0 above it, linear in between.
		var freqMask float64
		if low == high {
			freqMask = 0.0
		} else {
			t := (float64(i) - float64(low)) / float64(high-low)
			if t < 0 {
				t = 0
			}
			if t > 1 {
				t = 1
			}
			freqMask = 1.0 - t
		}

		// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
		freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
	}

	return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
|
||||
|
||||
// initYarn initializes YaRN RoPE if configured
|
||||
func (a *Attention) initYarn(cfg *Config) {
|
||||
a.YarnMscale = 1.0
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
|
||||
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
|
||||
cfg.HeadDim,
|
||||
cfg.RopeTheta,
|
||||
cfg.RopeScaling.Factor,
|
||||
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
|
||||
cfg.RopeScaling.BetaFast,
|
||||
cfg.RopeScaling.BetaSlow,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs one attention layer: project, reshape to heads, apply RoPE
// (YaRN-scaled when configured), update the KV cache, and attend with
// optional attention sinks.
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Position offset comes from the cache so RoPE continues where the
	// previous forward pass left off.
	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	if a.YarnFreqs != nil {
		// YaRN path: queries are magnitude-scaled by mscale, and both q/k
		// use the precomputed blended frequencies.
		if a.YarnMscale != 1.0 {
			q = mlx.MulScalar(q, a.YarnMscale)
		}
		q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
		k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
	} else {
		q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
		k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
	}

	if c != nil {
		k, v = c.Update(k, v, int(L))
	}

	out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
	// [B, n_heads, L, head_dim] -> [B, L, n_heads * head_dim] for the output projection.
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
// CreateSlidingWindowMask creates a causal mask with sliding window
|
||||
// Mirrors mlx-lm's create_causal_mask with window_size
|
||||
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
|
||||
// Build mask aligned to actual cache length (may be rotated)
|
||||
// rinds covers existing keys: [keyStart, keyStart+keyLen)
|
||||
// linds covers new queries: [queryStart, queryStart+seqLen)
|
||||
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
|
||||
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
|
||||
|
||||
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
|
||||
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
|
||||
|
||||
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
|
||||
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
|
||||
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
|
||||
|
||||
return mlx.LogicalAnd(causalMask, windowMask)
|
||||
}
|
||||
|
||||
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
	// Router scores each token against every expert.
	Router *nn.Linear `weight:"mlp.router"`
	// TopK is the number of experts selected per token.
	TopK       int32
	HiddenSize int32
	// GroupSize and Bits describe the expert weight quantization.
	GroupSize int
	Bits      int
	// Expert weights (loaded manually via sanitizeExpertWeights)
	GateBlocks, GateScales, GateBias *mlx.Array
	UpBlocks, UpScales, UpBias       *mlx.Array
	DownBlocks, DownScales, DownBias *mlx.Array
}
|
||||
|
||||
// Forward routes each token to its TopK experts, runs the quantized SwiGLU
// expert MLPs via GatherQMM, and combines the expert outputs weighted by the
// softmax of the selected router logits.
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
	// Select the TopK highest-logit experts per token (argpartition on the
	// negated logits) and softmax only over the selected logits.
	logits := moe.Router.Forward(x)
	neg := mlx.Neg(logits)
	part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
	topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
	topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
	weights := mlx.Softmax(topKVal, -1)

	xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
	idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)

	// For large batches, sort token/expert pairs by expert index so GatherQMM
	// can process each expert's tokens contiguously; invOrder undoes the sort.
	doSort := B*L >= 64
	var invOrder *mlx.Array
	sorted := false
	n := B * L * moe.TopK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
		sorted = true
	}

	gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
	up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)

	if moe.GateBias != nil {
		gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
	}
	if moe.UpBias != nil {
		up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
	}

	// Custom SwiGLU activation, compiled once and shared across layers.
	hidden := getCompiledSwiGLU().Call(gate, up)[0]

	down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
	if moe.DownBias != nil {
		down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
	}

	// Undo the expert sort (if applied) before the weighted combine.
	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	// Weighted sum over the TopK expert outputs per token.
	ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
	return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
|
||||
|
||||
// Block is one transformer layer: attention plus the MoE MLP, each wrapped
// in a pre-norm residual connection.
type Block struct {
	Attention    *Attention
	MLP          *MoE
	InputNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
	LayerType    string // "sliding_attention" or "full_attention"
}

// Forward applies the layer: x + attn(norm(x)), then h + moe(norm(h)).
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
	return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}
|
||||
|
||||
// Model is the full GPT-OSS text model: embeddings, transformer blocks,
// final norm, and LM head, plus its tokenizer and embedded config.
type Model struct {
	EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
	Layers      []*Block      `weight:"-"` // loaded manually due to MoE sanitization
	Norm        *nn.RMSNorm   `weight:"model.norm"`
	LMHead      *nn.Linear    `weight:"lm_head"`

	tok *tokenizer.Tokenizer
	*Config
}

// Tokenizer returns the model's tokenizer.
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }

// NumLayers reports the number of transformer blocks.
func (m *Model) NumLayers() int { return len(m.Layers) }

// VocabSize reports the vocabulary size from the config.
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
func (m *Model) NewCache(int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// Forward embeds tokens, builds the causal / sliding-window attention masks
// once at the model level, runs every transformer block, and returns the LM
// head logits.
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := tokens.Shape()[0], tokens.Shape()[1]
	x := m.EmbedTokens.Forward(tokens)

	// Find representative cache indices for sliding window attention
	var swaIdx int = -1
	for i, layer := range m.Layers {
		if layer.LayerType == "sliding_attention" {
			swaIdx = i
			break
		}
	}

	// Create masks once at model level
	var fullMask, swaMask *mlx.Array
	var fullMaskMode, swaMaskMode string

	// Single-token decode (L == 1) needs no mask at all; for longer inputs,
	// full-attention layers use the builtin "causal" mode, and sliding-window
	// layers need an explicit mask only when the input exceeds the window.
	if L > 1 {
		fullMaskMode = "causal"
		if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
			c := caches[swaIdx]
			offset := c.Offset()
			windowSize := int(m.SlidingWindow)
			cacheLen := min(int(L), windowSize)
			if offset > 0 {
				cacheLen = min(c.Len()+int(L), windowSize)
			}
			if int(L) > windowSize {
				swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
			} else {
				swaMaskMode = "causal"
			}
		} else {
			swaMaskMode = "causal"
		}
	}

	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil {
			c = caches[i]
		}
		// Pick the mask variant matching this layer's attention type.
		mask, maskMode := fullMask, fullMaskMode
		if layer.LayerType == "sliding_attention" {
			mask, maskMode = swaMask, swaMaskMode
		}
		x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
	}

	return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
|
||||
|
||||
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
//
// The merged tensors interleave gate and up along one axis: the gate half is
// the stride-2 slice starting at index 0, the up half the stride-2 slice
// starting at index 1. Each half is forced contiguous (mlx.Contiguous) before
// use. Missing tensors are simply skipped, leaving the matching MoE fields nil.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
	// Errors are ignored deliberately: absent tensors come back nil and are
	// handled by the nil checks below.
	gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
	gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
	gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
	downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
	downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
	downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")

	// The down projection is not merged, so its scales/bias pass through as-is.
	moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}

	if gateUpBlocks != nil {
		// View the packed quantized blocks as uint32 and flatten the last two dims,
		// then de-interleave gate (start 0) and up (start 1) with stride 2 on dim 1.
		gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
		s := gub.Shape()
		moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
	}
	if gateUpScales != nil {
		s := gateUpScales.Shape()
		moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
	}
	if gateUpBias != nil {
		// Bias is 2-D; the interleaved axis is the last dimension here.
		s := gateUpBias.Shape()
		moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
		moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
	}
	if downBlocks != nil {
		moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
	}
	return moe
}
|
||||
|
||||
// Load reads a model from modelPath: config.json for hyperparameters,
// safetensors shards for weights, and tokenizer.json for the tokenizer.
// Layers are loaded manually (Model.Layers is tagged weight:"-") because the
// MoE expert weights need de-interleaving via sanitizeExpertWeights.
func Load(modelPath string) (*Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	// Attention scale is derived, not read from config.json.
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("load tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]*Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load simple weights via struct tags
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Load layers with custom MoE handling
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		layer := &Block{}
		if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d: %w", i, err)
		}

		// Initialize attention YaRN
		layer.Attention.initYarn(&cfg)

		// Load MoE with weight sanitization; the sanitized MoE replaces the
		// tag-loaded MLP, keeping only its Router.
		moe := sanitizeExpertWeights(weights, prefix)
		moe.Router = layer.MLP.Router // Router was loaded by LoadModule
		moe.TopK = cfg.NumExpertsPerTok
		moe.HiddenSize = cfg.HiddenSize
		layer.MLP = moe

		// Set layer type; defaults to full attention when config lists fewer
		// layer types than layers.
		layer.LayerType = "full_attention"
		if int(i) < len(cfg.LayerTypes) {
			layer.LayerType = cfg.LayerTypes[i]
		}

		m.Layers[i] = layer
	}

	// Release safetensors BEFORE eval - lazy arrays have captured data,
	// this reduces peak memory by freeing mmap during materialization
	weights.ReleaseAll()
	mlx.Eval(mlx.Collect(m)...)

	return m, nil
}
|
||||
|
||||
func (m *Model) MaxContextLength() int32 {
|
||||
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
|
||||
return m.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
}
|
||||
return 131072
|
||||
}
|
||||
152
x/imagegen/models/llama/llama.go
Normal file
152
x/imagegen/models/llama/llama.go
Normal file
@@ -0,0 +1,152 @@
|
||||
//go:build mlx
|
||||
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Config holds Llama hyperparameters parsed from config.json.
// HeadDim and Scale are derived in Load rather than read from the file.
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	HeadDim               int32   `json:"-"` // HiddenSize / NumAttentionHeads, set in Load
	Scale                 float32 `json:"-"` // 1/sqrt(HeadDim), set in Load
}
|
||||
|
||||
// Model is a Llama-style decoder-only transformer. Output is tagged optional
// and is replaced in Load with a projection tied to EmbedTokens.Weight.
type Model struct {
	EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
	Layers      []*Layer      `weight:"model.layers"`
	Norm        *nn.RMSNorm   `weight:"model.norm"`
	Output      *nn.Linear    `weight:"lm_head,optional"`

	tok *tokenizer.Tokenizer // loaded from tokenizer.json in Load

	// Embedded hyperparameters parsed from config.json in Load.
	*Config
}
|
||||
|
||||
// Layer is one transformer block: attention and a gated MLP, each preceded
// by an RMSNorm and wrapped in a residual connection (see Layer.Forward).
type Layer struct {
	Attention     *Attention
	MLP           *MLP
	AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
	MLPNorm       *nn.RMSNorm `weight:"post_attention_layernorm"`
}
|
||||
|
||||
// Attention holds the four projection weights for multi-head attention.
// K/V may use fewer heads than Q (grouped-query attention, per
// Config.NumKeyValueHeads).
type Attention struct {
	QProj *nn.Linear `weight:"self_attn.q_proj"`
	KProj *nn.Linear `weight:"self_attn.k_proj"`
	VProj *nn.Linear `weight:"self_attn.v_proj"`
	OProj *nn.Linear `weight:"self_attn.o_proj"`
}
|
||||
|
||||
// MLP is the SwiGLU feed-forward block: SiLU(gate(x)) * up(x), then down.
type MLP struct {
	GateProj *nn.Linear `weight:"mlp.gate_proj"`
	UpProj   *nn.Linear `weight:"mlp.up_proj"`
	DownProj *nn.Linear `weight:"mlp.down_proj"`
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.Config)
|
||||
}
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
// Forward computes multi-head attention with rotary position embeddings and
// a per-layer KV cache.
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reinterpret each [B, L, heads*headDim] projection as [B, heads, L, headDim]
	// using strides alone (no copy): the sequence axis strides by heads*headDim
	// and the head axis by headDim. K/V use NumKeyValueHeads (grouped-query).
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Rotate q/k by absolute position; c.Offset() accounts for tokens already
	// stored in the cache.
	q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())

	// Append the new k/v to the cache and attend over the cached sequence.
	// Causal masking (L > 1) only matters when processing multiple new tokens.
	k, v = c.Update(k, v, int(L))
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	// [B, heads, L, headDim] -> [B, L, heads*headDim] for the output projection.
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
@@ -39,23 +39,19 @@ func Execute(args []string) error {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
// Detect model type from capabilities
|
||||
mode := detectModelMode(*modelName)
|
||||
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
||||
|
||||
if mode != ModeImageGen {
|
||||
return fmt.Errorf("imagegen runner only supports image generation models")
|
||||
}
|
||||
|
||||
// Initialize MLX only for image generation mode.
|
||||
// Initialize MLX
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
|
||||
// Detect model type from capabilities
|
||||
mode := detectModelMode(*modelName)
|
||||
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
||||
|
||||
// Create and start server
|
||||
server, err := newServer(*modelName, *port)
|
||||
server, err := newServer(*modelName, *port, mode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create server: %w", err)
|
||||
}
|
||||
@@ -65,6 +61,12 @@ func Execute(args []string) error {
|
||||
mux.HandleFunc("/health", server.healthHandler)
|
||||
mux.HandleFunc("/completion", server.completionHandler)
|
||||
|
||||
// LLM-specific endpoints
|
||||
if mode == ModeLLM {
|
||||
mux.HandleFunc("/tokenize", server.tokenizeHandler)
|
||||
mux.HandleFunc("/embedding", server.embeddingHandler)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
||||
Handler: mux,
|
||||
@@ -110,22 +112,34 @@ func detectModelMode(modelName string) ModelMode {
|
||||
|
||||
// server holds the model and handles HTTP requests.
|
||||
type server struct {
|
||||
mode ModelMode
|
||||
modelName string
|
||||
port int
|
||||
|
||||
// Image generation model.
|
||||
// Image generation model (when mode == ModeImageGen)
|
||||
imageModel ImageModel
|
||||
|
||||
// LLM model (when mode == ModeLLM)
|
||||
llmModel *llmState
|
||||
}
|
||||
|
||||
// newServer creates a new server instance for image generation models.
|
||||
func newServer(modelName string, port int) (*server, error) {
|
||||
// newServer creates a new server instance and loads the appropriate model.
|
||||
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
|
||||
s := &server{
|
||||
mode: mode,
|
||||
modelName: modelName,
|
||||
port: port,
|
||||
}
|
||||
|
||||
if err := s.loadImageModel(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||
switch mode {
|
||||
case ModeImageGen:
|
||||
if err := s.loadImageModel(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||
}
|
||||
case ModeLLM:
|
||||
if err := s.loadLLMModel(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@@ -149,5 +163,41 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.handleImageCompletion(w, r, req)
|
||||
switch s.mode {
|
||||
case ModeImageGen:
|
||||
s.handleImageCompletion(w, r, req)
|
||||
case ModeLLM:
|
||||
s.handleLLMCompletion(w, r, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if s.llmModel == nil {
|
||||
http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tok := s.llmModel.model.Tokenizer()
|
||||
tokens := tok.Encode(req.Content, false)
|
||||
|
||||
// Convert int32 to int for JSON response
|
||||
intTokens := make([]int, len(tokens))
|
||||
for i, t := range tokens {
|
||||
intTokens[i] = int(t)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
|
||||
}
|
||||
|
||||
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
@@ -30,12 +30,13 @@ import (
|
||||
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
|
||||
//
|
||||
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
||||
// like any other model. It is used for image generation models.
|
||||
// like any other model. It supports both LLM (safetensors) and image generation models.
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
mode ModelMode
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
@@ -44,7 +45,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
|
||||
func NewServer(modelName string) (*Server, error) {
|
||||
func NewServer(modelName string, mode ModelMode) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
if err := CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
@@ -118,6 +119,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
mode: mode,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
@@ -143,7 +145,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||
}
|
||||
@@ -394,7 +396,36 @@ func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, e
|
||||
|
||||
// Tokenize tokenizes the input content.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("tokenization not supported for image generation models")
|
||||
body, err := json.Marshal(map[string]string{"content": content})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.Tokens, nil
|
||||
}
|
||||
|
||||
// Detokenize converts tokens back to text.
|
||||
|
||||
@@ -30,80 +30,21 @@ type cacheSession struct {
|
||||
remaining []int32
|
||||
}
|
||||
|
||||
func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
|
||||
if c == nil {
|
||||
return dst
|
||||
}
|
||||
|
||||
keys, values := c.State()
|
||||
if keys != nil && keys.Valid() {
|
||||
dst = append(dst, keys)
|
||||
}
|
||||
if values != nil && values.Valid() {
|
||||
dst = append(dst, values)
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (c *kvCache) free() {
|
||||
for i, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
kv.Free()
|
||||
c.caches[i] = nil
|
||||
}
|
||||
c.caches = nil
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *kvCache) cachesCanTrim() bool {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.CanTrim() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *kvCache) trimToPrefix(prefix int) {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil || !kv.CanTrim() {
|
||||
continue
|
||||
}
|
||||
if trim := kv.Offset() - prefix; trim > 0 {
|
||||
kv.Trim(trim)
|
||||
}
|
||||
}
|
||||
if prefix < len(c.tokens) {
|
||||
c.tokens = c.tokens[:prefix]
|
||||
}
|
||||
}
|
||||
|
||||
// begin prepares caches for a new request. It finds the nearest
|
||||
// matching cache or creates new caches if none match.
|
||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
ensureCaches := func() {
|
||||
if len(c.caches) != 0 {
|
||||
return
|
||||
}
|
||||
if len(c.caches) == 0 {
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
return
|
||||
}
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
} else {
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
ensureCaches()
|
||||
|
||||
remaining := c.findRemaining(inputs)
|
||||
ensureCaches()
|
||||
|
||||
return &cacheSession{
|
||||
cache: c,
|
||||
@@ -115,36 +56,18 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
|
||||
// close saves the token state if the forward pass ran.
|
||||
func (s *cacheSession) close() {
|
||||
if len(s.caches) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
offset := -1
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, kv := range s.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
if offset := s.caches[0].Offset(); offset > 0 {
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, c := range s.caches {
|
||||
k, v := c.State()
|
||||
arrays = append(arrays, k, v)
|
||||
}
|
||||
// Mixed cache types (e.g. recurrent + KV) can transiently report different
|
||||
// offsets, so use the minimum as the safe reusable token prefix.
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
arrays = appendCacheState(arrays, kv)
|
||||
}
|
||||
if offset <= 0 {
|
||||
return
|
||||
}
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data.
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
stored := append(s.inputs, s.outputs...)
|
||||
if offset > len(stored) {
|
||||
offset = len(stored)
|
||||
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
|
||||
}
|
||||
s.cache.tokens = stored[:offset]
|
||||
}
|
||||
|
||||
// findRemaining finds the longest common prefix between tokens and the cached
|
||||
@@ -155,20 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||
prefix++
|
||||
}
|
||||
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if prefix == len(tokens) && prefix > 0 {
|
||||
prefix--
|
||||
}
|
||||
|
||||
if prefix < len(c.tokens) {
|
||||
if c.cachesCanTrim() {
|
||||
c.trimToPrefix(prefix)
|
||||
} else {
|
||||
c.free()
|
||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||
return tokens
|
||||
trim := len(c.tokens) - prefix
|
||||
for _, kv := range c.caches {
|
||||
kv.Trim(trim)
|
||||
}
|
||||
c.tokens = c.tokens[:prefix]
|
||||
}
|
||||
|
||||
if prefix == 0 {
|
||||
@@ -183,21 +98,10 @@ func (c *kvCache) log() {
|
||||
if len(c.caches) == 0 {
|
||||
return
|
||||
}
|
||||
offset := -1
|
||||
var totalBytes int
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
for _, a := range appendCacheState(nil, kv) {
|
||||
totalBytes += a.NumBytes()
|
||||
}
|
||||
k, v := kv.State()
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
}
|
||||
if offset < 0 {
|
||||
return
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
}
|
||||
|
||||
18
x/mlxrunner/cache/cache.go
vendored
18
x/mlxrunner/cache/cache.go
vendored
@@ -9,9 +9,7 @@ import (
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||
State() (keys, values *mlx.Array)
|
||||
CanTrim() bool
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Free()
|
||||
@@ -62,15 +60,13 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
}
|
||||
|
||||
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil, nil
|
||||
if c.offset == c.keys.Dim(2) {
|
||||
return c.keys, c.values
|
||||
}
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
@@ -187,15 +183,13 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil, nil
|
||||
if c.offset < c.keys.Dim(2) {
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *RotatingKVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
161
x/mlxrunner/cache/recurrent.go
vendored
161
x/mlxrunner/cache/recurrent.go
vendored
@@ -1,161 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
if old == v {
|
||||
return old
|
||||
}
|
||||
|
||||
mlx.Pin(v)
|
||||
if old != nil && old != v {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bool) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
if old == v {
|
||||
return old
|
||||
}
|
||||
|
||||
root := v
|
||||
if ensureContiguous {
|
||||
root = mlx.Contiguous(v, false)
|
||||
}
|
||||
detached := root.Clone()
|
||||
|
||||
mlx.Pin(detached)
|
||||
if old != nil && old != detached {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
|
||||
return detached
|
||||
}
|
||||
|
||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
snap := mlx.Copy(a)
|
||||
mlx.Eval(snap)
|
||||
mlx.Pin(snap)
|
||||
return snap
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
return
|
||||
}
|
||||
|
||||
if needConv {
|
||||
c.convState = c.setStateRaw(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
if needDelta {
|
||||
c.deltaState = c.setStateRaw(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
c.convState = c.setStateDetached(c.convState, v, true)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
c.deltaState = c.setStateDetached(c.deltaState, v, false)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.convState, c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
||||
_ = n
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
convState: snapshotPinned(c.convState),
|
||||
deltaState: snapshotPinned(c.deltaState),
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
mlx.Unpin(c.convState, c.deltaState)
|
||||
c.convState, c.deltaState = nil, nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -18,10 +19,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
@@ -29,16 +28,15 @@ import (
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
port int
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
memory uint
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
@@ -193,19 +191,6 @@ type completionOpts struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
@@ -265,24 +250,28 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw CompletionResponse
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: raw.EvalDuration,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
@@ -295,7 +284,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return int(c.contextLength.Load())
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
@@ -349,10 +338,9 @@ func (c *Client) Pid() int {
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status int
|
||||
Progress int
|
||||
ContextLength int
|
||||
Memory uint64
|
||||
Status int
|
||||
Progress int
|
||||
Memory uint
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
@@ -375,10 +363,7 @@ func (c *Client) Ping(ctx context.Context) error {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.contextLength.Store(int64(status.ContextLength))
|
||||
c.memory.Store(status.Memory)
|
||||
|
||||
c.memory = status.Memory
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -411,7 +396,7 @@ func (c *Client) currentMemory() uint64 {
|
||||
if err := c.Ping(ctx); err != nil {
|
||||
slog.Warn("failed to get current memory", "error", err)
|
||||
}
|
||||
return c.memory.Load()
|
||||
return uint64(c.memory)
|
||||
}
|
||||
|
||||
// MemorySize implements llm.LlamaServer.
|
||||
|
||||
@@ -7,6 +7,4 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned int
|
||||
pinned bool
|
||||
}
|
||||
|
||||
var arrays []*Array
|
||||
@@ -129,7 +129,7 @@ func (t *Array) Clone() *Array {
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned++
|
||||
t.pinned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func Pin(s ...*Array) {
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned--
|
||||
t.pinned = false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func Unpin(s ...*Array) {
|
||||
func Sweep() {
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned > 0 && t.Valid() {
|
||||
if t.pinned && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
@@ -175,7 +175,7 @@ func (t *Array) String() string {
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Int("pinned", t.pinned),
|
||||
slog.Bool("pinned", t.pinned),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
|
||||
@@ -1,370 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
dtype := q.DType()
|
||||
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
|
||||
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
|
||||
if repeatFactor <= 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Dims()
|
||||
x = ExpandDims(x, 3)
|
||||
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
|
||||
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
|
||||
}
|
||||
|
||||
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
|
||||
Hv, Dv := int32(vd[2]), int32(vd[3])
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
if vd[0] != int(B) || vd[1] != int(T) {
|
||||
return nil, nil
|
||||
}
|
||||
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
repeatFactor := int(Hv / Hk)
|
||||
q = repeatHeadsForGatedDelta(q, repeatFactor)
|
||||
k = repeatHeadsForGatedDelta(k, repeatFactor)
|
||||
|
||||
nextState = state
|
||||
if T == 1 {
|
||||
qt := Squeeze(q, 1)
|
||||
kt := Squeeze(k, 1)
|
||||
vt := Squeeze(v, 1)
|
||||
gt := Squeeze(g, 1)
|
||||
bt := Squeeze(beta, 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
return ExpandDims(yt, 1), nextState
|
||||
}
|
||||
|
||||
outs := make([]*Array, 0, T)
|
||||
for t := int32(0); t < T; t++ {
|
||||
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
|
||||
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
outs = append(outs, ExpandDims(yt, 1))
|
||||
}
|
||||
return Concatenate(outs, 1), nextState
|
||||
}
|
||||
|
||||
// GatedDelta runs the recurrent update operation.
|
||||
//
|
||||
// It uses the fused Metal kernel when available and otherwise falls back to a
|
||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||
return y, nextState
|
||||
}
|
||||
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
|
||||
if y == nil || nextState == nil {
|
||||
panic("mlx.GatedDelta: fallback failed (invalid inputs or unsupported shapes)")
|
||||
}
|
||||
return y, nextState
|
||||
}
|
||||
@@ -64,10 +64,6 @@ func PeakMemory() int {
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
func ResetPeakMemory() {
|
||||
C.mlx_reset_peak_memory()
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
|
||||
@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output != nil && output.Valid() {
|
||||
if output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -429,15 +378,6 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
func Copy(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("COPY")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -20,7 +20,6 @@ type Model interface {
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
|
||||
@@ -6,36 +6,18 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func prefillChunkSize() int {
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var (
|
||||
sample, logprobs *mlx.Array
|
||||
nextSample, nextLogprobs *mlx.Array
|
||||
@@ -51,57 +33,42 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
mlx.LogArrays()
|
||||
r.cache.log()
|
||||
}
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
|
||||
if len(inputs) >= r.contextLength {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.MaxTokens <= 0 {
|
||||
request.Options.MaxTokens = maxGenerate
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = appendCacheState(state, c)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
n := min(2<<10, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
@@ -126,17 +93,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
@@ -144,16 +113,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.EvalCount = i
|
||||
final.CompletionTokens = i
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
}:
|
||||
}
|
||||
|
||||
@@ -166,10 +137,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalDuration = time.Since(now)
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,15 +4,14 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -22,7 +21,7 @@ import (
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan CompletionResponse
|
||||
Responses chan Response
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
@@ -44,12 +43,25 @@ type TextCompletionsRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Text string `json:"content,omitempty"`
|
||||
Token int `json:"token,omitempty"`
|
||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
|
||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
CompletionTokens int `json:"eval_count,omitempty"`
|
||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
contextLength int
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -78,7 +90,6 @@ func (r *Runner) Load(modelName string) error {
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -147,17 +158,6 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -51,10 +51,9 @@ func Execute(args []string) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
Memory: uint(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -80,7 +79,7 @@ func Execute(args []string) error {
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
request := Request{Responses: make(chan Response)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
@@ -89,6 +88,9 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
if request.Options.MaxTokens < 1 {
|
||||
request.Options.MaxTokens = 16 << 10
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
|
||||
@@ -430,10 +430,6 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
@@ -262,10 +262,6 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -15,40 +15,6 @@ type LinearLayer interface {
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// Conv1d applies 1D convolution over NLC input.
|
||||
type Conv1d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Stride int32
|
||||
Padding int32
|
||||
Dilation int32
|
||||
Groups int32
|
||||
}
|
||||
|
||||
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||
if stride <= 0 {
|
||||
stride = 1
|
||||
}
|
||||
if dilation <= 0 {
|
||||
dilation = 1
|
||||
}
|
||||
if groups <= 0 {
|
||||
groups = 1
|
||||
}
|
||||
return &Conv1d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
Dilation: dilation,
|
||||
Groups: groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
type Linear struct {
|
||||
Weight *mlx.Array
|
||||
|
||||
@@ -279,10 +279,6 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,159 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTensorPathLayout(t *testing.T) {
|
||||
dummy := mlx.New("dummy")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantContainer string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
key: "model.embed_tokens.weight",
|
||||
wantContainer: "",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model with inner model",
|
||||
key: "model.language_model.model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model without inner model",
|
||||
key: "model.language_model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||
tt.key: dummy,
|
||||
})
|
||||
|
||||
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||
t.Fatalf(
|
||||
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||
layout.containerPrefix,
|
||||
layout.modelPrefix,
|
||||
tt.wantContainer,
|
||||
tt.wantModel,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
Reference in New Issue
Block a user