Compare commits

...

16 Commits

Author SHA1 Message Date
Parth Sareen
465d124183 cmd: fix opencode config (#13894) 2026-01-24 18:42:56 -08:00
Parth Sareen
d310e56fa3 cmd: add fallback for claude (#13892) 2026-01-24 18:26:01 -08:00
Jeffrey Morgan
a1ca428c90 glm4moelite: fix attention scale calculation (#13893)
Use the original key dimension (qkNopeHeadDim + qkRopeHeadDim = 256) for
the attention scale instead of the MLA absorbed dimension (kvLoraRank +
qkRopeHeadDim = 576).

MLA absorption is a mathematically equivalent reorganization of the
attention computation - it should not change the effective attention
scale. The scale should match training, which uses 1/sqrt(256).

This improves tool calling and model looping issues.
2026-01-24 17:48:09 -08:00
Jeffrey Morgan
16750865d1 glm4moelite: quantize more tensors to q8_0 and avoid double BOS token (#13891) 2026-01-24 16:33:54 -08:00
Jeffrey Morgan
f3b476c592 build: add -O3 optimization to CGO flags (#13877)
CGO_CFLAGS and CGO_CXXFLAGS were being set without optimization flags,
which overrides Go's default -O2 and results in unoptimized C++ code.

This caused significant performance degradation in release builds
compared to local `go build` which uses the default optimization.

- build_darwin.sh: add -O3 to CGO_CFLAGS and CGO_CXXFLAGS exports
- Dockerfile: preserve CGO_CFLAGS/CGO_CXXFLAGS from build args instead
  of overwriting them
- app/README.md: update documentation to include -O3
2026-01-24 10:55:38 -08:00
Parth Sareen
5267d31d56 docs: ollama launch (#13852) 2026-01-23 23:18:50 -08:00
Stillhart
b44f56319f README: Update the "Ollama for ruby" to the most popular and maintained ruby gem. (#13855)
* update README ruby link

the ollama-ai ruby gem is vastly less popular and seems unmaintained
https://rubygems.org/gems/ollama-ai

the defacto standard with the most downloads in the ruby ecosystem is ruby_llm
https://rubygems.org/gems/ruby_llm

I would link to that to avoid complication and guarantee feature compatibility with ollama.

* Update gem link ruby_llm from website to GitHub

ollama links mostly to github, not project websites, hence link to ruby_llm github.
2026-01-24 01:24:52 -05:00
Jeffrey Morgan
0209c268bb llama: fix CUDA MMA errors in release build (#13874) 2026-01-23 20:10:04 -08:00
Jeffrey Morgan
912d984346 llama: fix fattn-tile shared memory overflow on sm_50/52 (#13872)
Use nthreads=128 for ncols=4 configurations in flash attention tile
kernel to reduce shared memory usage below 48KB limit on Maxwell
architectures (sm_50/52).

With nthreads=256 and ncols=4, np=2 which caused shared memory to
exceed 48KB. With nthreads=128 and ncols=4, np=1 keeps shared memory
under the limit.
2026-01-23 19:22:32 -08:00
Parth Sareen
aae6ecbaff cmd: rename ollama config to ollama launch (#13871) 2026-01-23 18:40:40 -08:00
Jeffrey Morgan
64737330a4 Re-apply "model: add MLA absorption for glm4moelite" with fix (#13870)
The nvidia_fp32 config for (576, 512) head sizes had nbatch_fa=32,
which caused zero-sized arrays when computing array dimensions:
  nbatch_fa / (np * warp_size) = 32 / (2 * 32) = 0

This resulted in CUDA compilation failures on CUDA 12 (Windows and
Linux arm64):
- "static assertion failed with nbatch_fa % (np*warp_size) != 0"
- "the size of an array must be greater than zero"

Fix by changing nbatch_fa from 32 to 64 for all (576, 512) configs
in the nvidia_fp32 function, matching the nvidia_fp16 and AMD configs.
2026-01-23 18:40:28 -08:00
Jeffrey Morgan
2eda97f1c3 Revert "model: add MLA absorption for glm4moelite (#13810)" (#13869)
This reverts commit 1044b0419a.
2026-01-23 17:14:15 -08:00
Jeffrey Morgan
66831dcf70 x/imagegen: fix image editing support (#13866)
- Fix panic in ollama show for image gen models (safe type assertion)
- Add vision capability for Flux2KleinPipeline models at create time
- Flatten transparent PNG images onto white background for better results
2026-01-23 15:37:17 -08:00
Jeffrey Morgan
1044b0419a model: add MLA absorption for glm4moelite (#13810)
* model: add MLA absorption for glm4moelite

Split the combined KV_B tensor into separate K_B and V_B tensors
during conversion, enabling MLA (Multi-head Latent Attention)
absorption which compresses the KV cache for improved efficiency.

* ggml: enable MLA flash attention for GLM-4.7-flash

Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).

Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups

CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4

* model: add compatibility validation for glm4moelite architecture
2026-01-23 14:47:42 -08:00
Parth Sareen
771d9280ec cmd: ollama config fix droid model name configuration (#13856) 2026-01-23 11:44:22 -08:00
Jeffrey Morgan
862bc0a3bf x/imagegen: respect stream=false in /api/generate (#13853)
When stream=false is set for image generation requests, return a single
JSON response instead of streaming multiple ndjson progress updates.
2026-01-22 22:16:39 -08:00
48 changed files with 2201 additions and 210 deletions

View File

@@ -169,8 +169,10 @@ COPY . .
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .

View File

@@ -558,7 +558,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)

View File

@@ -75,9 +75,9 @@ The `-dev` flag enables:
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:
```
export CGO_CFLAGS=-mmacosx-version-min=12.0
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
export CGO_LDFLAGS=-mmacosx-version-min=12.0
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
```

View File

@@ -1019,8 +1019,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
arch, _ := resp.ModelInfo["general.architecture"].(string)
if arch != "" {
rows = append(rows, []string{"", "architecture", arch})
}
var paramStr string
if resp.Details.ParameterSize != "" {
@@ -1030,7 +1032,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if paramStr != "" {
rows = append(rows, []string{"", "parameters", paramStr})
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
@@ -2027,7 +2031,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
config.ConfigCmd(checkServerHeartbeat),
config.LaunchCmd(checkServerHeartbeat),
)
return rootCmd

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
)
// Claude implements Runner for Claude Code integration
@@ -18,12 +20,32 @@ func (c *Claude) args(model string) []string {
return nil
}
func (c *Claude) findPath() (string, error) {
if p, err := exec.LookPath("claude"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(home, ".claude", "local", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Claude) Run(model string) error {
if _, err := exec.LookPath("claude"); err != nil {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command("claude", c.args(model)...)
cmd := exec.Command(claudePath, c.args(model)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -1,6 +1,9 @@
package config
import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
)
@@ -19,6 +22,62 @@ func TestClaudeIntegration(t *testing.T) {
})
}
func TestClaudeFindPath(t *testing.T) {
c := &Claude{}
t.Run("finds claude in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
name := "claude"
if runtime.GOOS == "windows" {
name = "claude.exe"
}
fallback := filepath.Join(tmpDir, ".claude", "local", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestClaudeArgs(t *testing.T) {
c := &Claude{}

View File

@@ -7,14 +7,23 @@ import (
"os/exec"
"path/filepath"
"slices"
"strings"
)
// Droid implements Runner and Editor for Droid integration
type Droid struct{}
// droidModelEntry represents a custom model entry in Droid's settings.json
type droidModelEntry struct {
// droidSettings represents the Droid settings.json file (only fields we use)
type droidSettings struct {
CustomModels []modelEntry `json:"customModels"`
SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
}
type sessionSettings struct {
Model string `json:"model"`
ReasoningEffort string `json:"reasoningEffort"`
}
type modelEntry struct {
Model string `json:"model"`
DisplayName string `json:"displayName"`
BaseURL string `json:"baseUrl"`
@@ -76,24 +85,36 @@ func (d *Droid) Edit(models []string) error {
return err
}
settings := make(map[string]any)
// Read file once, unmarshal twice:
// map preserves unknown fields for writing back (including extra fields in model entries)
settingsMap := make(map[string]any)
var settings droidSettings
if data, err := os.ReadFile(settingsPath); err == nil {
if err := json.Unmarshal(data, &settings); err != nil {
if err := json.Unmarshal(data, &settingsMap); err != nil {
return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
}
json.Unmarshal(data, &settings) // ignore error, zero values are fine
}
// Keep only non-Ollama models from the raw map (preserves extra fields)
// Rebuild Ollama models
var nonOllamaModels []any
if rawModels, ok := settingsMap["customModels"].([]any); ok {
for _, raw := range rawModels {
if m, ok := raw.(map[string]any); ok {
if m["apiKey"] != "ollama" {
nonOllamaModels = append(nonOllamaModels, raw)
}
}
}
}
customModels, _ := settings["customModels"].([]any)
// Keep only non-Ollama models (we'll rebuild Ollama models fresh)
nonOllamaModels := slices.DeleteFunc(slices.Clone(customModels), isOllamaModelEntry)
// Build new Ollama model entries with sequential indices (0, 1, 2, ...)
var ollamaModels []any
var newModels []any
var defaultModelID string
for i, model := range models {
modelID := fmt.Sprintf("custom:%s-[Ollama]-%d", model, i)
ollamaModels = append(ollamaModels, droidModelEntry{
modelID := fmt.Sprintf("custom:%s-%d", model, i)
newModels = append(newModels, modelEntry{
Model: model,
DisplayName: model,
BaseURL: "http://localhost:11434/v1",
@@ -109,21 +130,22 @@ func (d *Droid) Edit(models []string) error {
}
}
settings["customModels"] = append(ollamaModels, nonOllamaModels...)
settingsMap["customModels"] = append(newModels, nonOllamaModels...)
sessionSettings, ok := settings["sessionDefaultSettings"].(map[string]any)
// Update session default settings (preserve unknown fields in the nested object)
sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
if !ok {
sessionSettings = make(map[string]any)
}
sessionSettings["model"] = defaultModelID
if effort, ok := sessionSettings["reasoningEffort"].(string); !ok || !isValidReasoningEffort(effort) {
if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
sessionSettings["reasoningEffort"] = "none"
}
settings["sessionDefaultSettings"] = sessionSettings
settingsMap["sessionDefaultSettings"] = sessionSettings
data, err := json.MarshalIndent(settings, "", " ")
data, err := json.MarshalIndent(settingsMap, "", " ")
if err != nil {
return err
}
@@ -135,24 +157,21 @@ func (d *Droid) Models() []string {
if err != nil {
return nil
}
settings, err := readJSONFile(filepath.Join(home, ".factory", "settings.json"))
data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
if err != nil {
return nil
}
customModels, _ := settings["customModels"].([]any)
var settings droidSettings
if err := json.Unmarshal(data, &settings); err != nil {
return nil
}
var result []string
for _, m := range customModels {
if !isOllamaModelEntry(m) {
continue
}
entry, ok := m.(map[string]any)
if !ok {
continue
}
if model, _ := entry["model"].(string); model != "" {
result = append(result, model)
for _, m := range settings.CustomModels {
if m.APIKey == "ollama" {
result = append(result, m.Model)
}
}
return result
@@ -163,12 +182,3 @@ var validReasoningEfforts = []string{"high", "medium", "low", "none"}
func isValidReasoningEffort(effort string) bool {
return slices.Contains(validReasoningEfforts, effort)
}
func isOllamaModelEntry(m any) bool {
entry, ok := m.(map[string]any)
if !ok {
return false
}
id, _ := entry["id"].(string)
return strings.Contains(id, "-[Ollama]-")
}

View File

@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
@@ -75,8 +76,8 @@ func TestDroidEdit(t *testing.T) {
if models[0]["model"] != "model-a" {
t.Errorf("expected model-a, got %s", models[0]["model"])
}
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
if models[0]["id"] != "custom:model-a-0" {
t.Errorf("expected custom:model-a-0, got %s", models[0]["id"])
}
if models[0]["index"] != float64(0) {
t.Errorf("expected index 0, got %v", models[0]["index"])
@@ -86,8 +87,8 @@ func TestDroidEdit(t *testing.T) {
if models[1]["model"] != "model-b" {
t.Errorf("expected model-b, got %s", models[1]["model"])
}
if models[1]["id"] != "custom:model-b-[Ollama]-1" {
t.Errorf("expected custom:model-b-[Ollama]-1, got %s", models[1]["id"])
if models[1]["id"] != "custom:model-b-1" {
t.Errorf("expected custom:model-b-1, got %s", models[1]["id"])
}
if models[1]["index"] != float64(1) {
t.Errorf("expected index 1, got %v", models[1]["index"])
@@ -105,8 +106,8 @@ func TestDroidEdit(t *testing.T) {
if !ok {
t.Fatal("sessionDefaultSettings not found")
}
if session["model"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", session["model"])
if session["model"] != "custom:model-a-0" {
t.Errorf("expected custom:model-a-0, got %s", session["model"])
}
})
@@ -134,11 +135,11 @@ func TestDroidEdit(t *testing.T) {
}
// Check IDs match new indices
if models[0]["id"] != "custom:model-a-[Ollama]-0" {
t.Errorf("expected custom:model-a-[Ollama]-0, got %s", models[0]["id"])
if models[0]["id"] != "custom:model-a-0" {
t.Errorf("expected custom:model-a-0, got %s", models[0]["id"])
}
if models[1]["id"] != "custom:model-c-[Ollama]-1" {
t.Errorf("expected custom:model-c-[Ollama]-1, got %s", models[1]["id"])
if models[1]["id"] != "custom:model-c-1" {
t.Errorf("expected custom:model-c-1, got %s", models[1]["id"])
}
})
@@ -390,13 +391,13 @@ func TestDroidEdit_MalformedModelEntry(t *testing.T) {
t.Fatalf("Edit with malformed entries failed: %v", err)
}
// Malformed entries should be preserved in nonOllamaModels
// Malformed entries (non-object) are dropped - only valid model objects are preserved
settings, _ := readJSONFile(settingsPath)
customModels, _ := settings["customModels"].([]any)
// Should have: 1 new Ollama model + 2 preserved malformed entries
if len(customModels) != 3 {
t.Errorf("expected 3 entries (1 new + 2 preserved malformed), got %d", len(customModels))
// Should have: 1 new Ollama model only (malformed entries dropped)
if len(customModels) != 1 {
t.Errorf("expected 1 entry (malformed entries dropped), got %d", len(customModels))
}
}
@@ -428,6 +429,253 @@ func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
}
}
// testDroidSettingsFixture is a representative settings.json fixture for testing.
// It covers: simple fields, arrays, nested objects, and customModels.
const testDroidSettingsFixture = `{
"commandAllowlist": ["ls", "pwd", "git status"],
"diffMode": "github",
"enableHooks": true,
"hooks": {
"claudeHooksImported": true,
"importedClaudeHooks": ["uv run ruff check", "echo test"]
},
"ideExtensionPromptedAt": {
"cursor": 1763081579486,
"vscode": 1762992990179
},
"customModels": [
{
"model": "existing-ollama-model",
"displayName": "existing-ollama-model",
"baseUrl": "http://localhost:11434/v1",
"apiKey": "ollama",
"provider": "generic-chat-completion-api",
"maxOutputTokens": 64000,
"supportsImages": false,
"id": "custom:existing-ollama-model-0",
"index": 0
},
{
"model": "gpt-4",
"displayName": "GPT-4",
"baseUrl": "https://api.openai.com/v1",
"apiKey": "sk-xxx",
"provider": "openai",
"maxOutputTokens": 4096,
"supportsImages": true,
"id": "openai-gpt4",
"index": 1,
"customField": "should be preserved"
}
],
"sessionDefaultSettings": {
"autonomyMode": "auto-medium",
"model": "custom:existing-ollama-model-0",
"reasoningEffort": "high"
},
"todoDisplayMode": "pinned"
}`
func TestDroidEdit_RoundTrip(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
// Edit with new models
if err := d.Edit([]string{"llama3", "mistral"}); err != nil {
t.Fatal(err)
}
// Read back and verify
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
// Verify unknown top-level fields preserved
if settings["diffMode"] != "github" {
t.Error("diffMode not preserved")
}
if settings["enableHooks"] != true {
t.Error("enableHooks not preserved")
}
if settings["todoDisplayMode"] != "pinned" {
t.Error("todoDisplayMode not preserved")
}
// Verify arrays preserved
allowlist, ok := settings["commandAllowlist"].([]any)
if !ok || len(allowlist) != 3 {
t.Error("commandAllowlist not preserved")
}
// Verify nested objects preserved
hooks, ok := settings["hooks"].(map[string]any)
if !ok {
t.Fatal("hooks not preserved")
}
if hooks["claudeHooksImported"] != true {
t.Error("hooks.claudeHooksImported not preserved")
}
importedHooks, ok := hooks["importedClaudeHooks"].([]any)
if !ok || len(importedHooks) != 2 {
t.Error("hooks.importedClaudeHooks not preserved")
}
// Verify deeply nested numeric values preserved
idePrompted, ok := settings["ideExtensionPromptedAt"].(map[string]any)
if !ok {
t.Fatal("ideExtensionPromptedAt not preserved")
}
if idePrompted["cursor"] != float64(1763081579486) {
t.Error("ideExtensionPromptedAt.cursor not preserved")
}
// Verify sessionDefaultSettings unknown fields preserved
session, ok := settings["sessionDefaultSettings"].(map[string]any)
if !ok {
t.Fatal("sessionDefaultSettings not preserved")
}
if session["autonomyMode"] != "auto-medium" {
t.Error("sessionDefaultSettings.autonomyMode not preserved")
}
if session["reasoningEffort"] != "high" {
t.Error("sessionDefaultSettings.reasoningEffort not preserved (was valid)")
}
// model should be updated
if session["model"] != "custom:llama3-0" {
t.Errorf("sessionDefaultSettings.model not updated, got %s", session["model"])
}
// Verify customModels: old ollama replaced, non-ollama preserved with extra fields
models, ok := settings["customModels"].([]any)
if !ok {
t.Fatal("customModels not preserved")
}
if len(models) != 3 { // 2 new ollama + 1 non-ollama
t.Fatalf("expected 3 models, got %d", len(models))
}
// First two should be new Ollama models
m0 := models[0].(map[string]any)
if m0["model"] != "llama3" || m0["apiKey"] != "ollama" {
t.Error("first model should be llama3")
}
m1 := models[1].(map[string]any)
if m1["model"] != "mistral" || m1["apiKey"] != "ollama" {
t.Error("second model should be mistral")
}
// Third should be preserved non-Ollama with extra field
m2 := models[2].(map[string]any)
if m2["model"] != "gpt-4" {
t.Error("non-Ollama model not preserved")
}
if m2["customField"] != "should be preserved" {
t.Error("non-Ollama model's extra field not preserved")
}
}
func TestDroidEdit_PreservesUnknownFields(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
readSettings := func() map[string]any {
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
return settings
}
t.Run("preserves all JSON value types", func(t *testing.T) {
os.RemoveAll(settingsDir)
os.MkdirAll(settingsDir, 0o755)
original := `{
"stringField": "value",
"numberField": 42,
"floatField": 3.14,
"boolField": true,
"nullField": null,
"arrayField": [1, "two", true],
"objectField": {"nested": "value"},
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
settings := readSettings()
if settings["stringField"] != "value" {
t.Error("stringField not preserved")
}
if settings["numberField"] != float64(42) {
t.Error("numberField not preserved")
}
if settings["floatField"] != 3.14 {
t.Error("floatField not preserved")
}
if settings["boolField"] != true {
t.Error("boolField not preserved")
}
if settings["nullField"] != nil {
t.Error("nullField not preserved")
}
arr, ok := settings["arrayField"].([]any)
if !ok || len(arr) != 3 {
t.Error("arrayField not preserved")
}
obj, ok := settings["objectField"].(map[string]any)
if !ok || obj["nested"] != "value" {
t.Error("objectField not preserved")
}
})
t.Run("preserves extra fields in non-Ollama models", func(t *testing.T) {
os.RemoveAll(settingsDir)
os.MkdirAll(settingsDir, 0o755)
original := `{
"customModels": [{
"model": "gpt-4",
"apiKey": "sk-xxx",
"extraField": "preserved",
"nestedExtra": {"foo": "bar"}
}]
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"llama3"}); err != nil {
t.Fatal(err)
}
settings := readSettings()
models := settings["customModels"].([]any)
gpt4 := models[1].(map[string]any) // non-Ollama is second
if gpt4["extraField"] != "preserved" {
t.Error("extraField not preserved")
}
nested := gpt4["nestedExtra"].(map[string]any)
if nested["foo"] != "bar" {
t.Error("nestedExtra not preserved")
}
})
}
func TestIsValidReasoningEffort(t *testing.T) {
tests := []struct {
effort string
@@ -452,3 +700,603 @@ func TestIsValidReasoningEffort(t *testing.T) {
})
}
}
func TestDroidEdit_Idempotent(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
// Edit twice with same models
d.Edit([]string{"llama3", "mistral"})
firstData, _ := os.ReadFile(settingsPath)
d.Edit([]string{"llama3", "mistral"})
secondData, _ := os.ReadFile(settingsPath)
// Results should be identical
if string(firstData) != string(secondData) {
t.Error("repeated edits with same models produced different results")
}
}
func TestDroidEdit_MultipleConsecutiveEdits(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
// Multiple edits shouldn't accumulate garbage or lose data
for i := range 10 {
models := []string{"model-a", "model-b"}
if i%2 == 0 {
models = []string{"model-x", "model-y", "model-z"}
}
if err := d.Edit(models); err != nil {
t.Fatalf("edit %d failed: %v", i, err)
}
}
// Verify file is still valid JSON and preserves original fields
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
}
// Original fields should still be there
if settings["diffMode"] != "github" {
t.Error("diffMode lost after multiple edits")
}
if settings["enableHooks"] != true {
t.Error("enableHooks lost after multiple edits")
}
// Non-Ollama model should still be preserved
models := settings["customModels"].([]any)
foundOther := false
for _, m := range models {
if entry, ok := m.(map[string]any); ok {
if entry["model"] == "gpt-4" {
foundOther = true
if entry["customField"] != "should be preserved" {
t.Error("other customField lost after multiple edits")
}
}
}
}
if !foundOther {
t.Error("other model lost after multiple edits")
}
}
func TestDroidEdit_UnicodeAndSpecialCharacters(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Settings with unicode and special characters
original := `{
"userName": "日本語テスト",
"emoji": "🚀🎉💻",
"specialChars": "quotes: \"test\" and 'test', backslash: \\, newline: \n, tab: \t",
"unicodeEscape": "\u0048\u0065\u006c\u006c\u006f",
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
if settings["userName"] != "日本語テスト" {
t.Error("Japanese characters not preserved")
}
if settings["emoji"] != "🚀🎉💻" {
t.Error("emoji not preserved")
}
// Note: JSON encoding will normalize escape sequences
if settings["unicodeEscape"] != "Hello" {
t.Error("unicode escape sequence not preserved")
}
}
func TestDroidEdit_LargeNumbers(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Large numbers and timestamps (common in settings files)
original := `{
"timestamp": 1763081579486,
"largeInt": 9007199254740991,
"negativeNum": -12345,
"floatNum": 3.141592653589793,
"scientificNotation": 1.23e10,
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
if settings["timestamp"] != float64(1763081579486) {
t.Errorf("timestamp not preserved: got %v", settings["timestamp"])
}
if settings["largeInt"] != float64(9007199254740991) {
t.Errorf("largeInt not preserved: got %v", settings["largeInt"])
}
if settings["negativeNum"] != float64(-12345) {
t.Error("negativeNum not preserved")
}
if settings["floatNum"] != 3.141592653589793 {
t.Error("floatNum not preserved")
}
}
func TestDroidEdit_EmptyAndNullValues(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
original := `{
"emptyString": "",
"nullValue": null,
"emptyArray": [],
"emptyObject": {},
"falseBool": false,
"zeroNumber": 0,
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
if settings["emptyString"] != "" {
t.Error("emptyString not preserved")
}
if settings["nullValue"] != nil {
t.Error("nullValue not preserved as null")
}
if arr, ok := settings["emptyArray"].([]any); !ok || len(arr) != 0 {
t.Error("emptyArray not preserved")
}
if obj, ok := settings["emptyObject"].(map[string]any); !ok || len(obj) != 0 {
t.Error("emptyObject not preserved")
}
if settings["falseBool"] != false {
t.Error("falseBool not preserved (false vs missing)")
}
if settings["zeroNumber"] != float64(0) {
t.Error("zeroNumber not preserved")
}
}
func TestDroidEdit_DeeplyNestedStructures(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
original := `{
"level1": {
"level2": {
"level3": {
"level4": {
"deepValue": "found me",
"deepArray": [1, 2, {"nested": true}]
}
}
}
},
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
// Navigate to deeply nested value
l1 := settings["level1"].(map[string]any)
l2 := l1["level2"].(map[string]any)
l3 := l2["level3"].(map[string]any)
l4 := l3["level4"].(map[string]any)
if l4["deepValue"] != "found me" {
t.Error("deeply nested value not preserved")
}
deepArray := l4["deepArray"].([]any)
if len(deepArray) != 3 {
t.Error("deeply nested array not preserved")
}
nestedInArray := deepArray[2].(map[string]any)
if nestedInArray["nested"] != true {
t.Error("object nested in array not preserved")
}
}
func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
// Test model names with colons, slashes, special chars
specialModels := []string{
"qwen3:480b-cloud",
"llama3.2:70b",
"model/with/slashes",
"model-with-dashes",
"model_with_underscores",
}
if err := d.Edit(specialModels); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models := settings["customModels"].([]any)
if len(models) != len(specialModels) {
t.Fatalf("expected %d models, got %d", len(specialModels), len(models))
}
for i, expected := range specialModels {
m := models[i].(map[string]any)
if m["model"] != expected {
t.Errorf("model %d: expected %s, got %s", i, expected, m["model"])
}
}
}
func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// No customModels key at all
original := `{
"diffMode": "github",
"sessionDefaultSettings": {"autonomyMode": "auto-high"}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
// Original fields preserved
if settings["diffMode"] != "github" {
t.Error("diffMode not preserved")
}
// customModels created
models, ok := settings["customModels"].([]any)
if !ok || len(models) != 1 {
t.Error("customModels not created properly")
}
}
func TestDroidEdit_NullCustomModels(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
original := `{
"customModels": null,
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
models, ok := settings["customModels"].([]any)
if !ok || len(models) != 1 {
t.Error("null customModels not handled properly")
}
}
func TestDroidEdit_MinifiedJSON(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Minified JSON (no whitespace)
original := `{"diffMode":"github","enableHooks":true,"hooks":{"imported":["cmd1","cmd2"]},"customModels":[],"sessionDefaultSettings":{}}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatal("output is not valid JSON")
}
if settings["diffMode"] != "github" {
t.Error("diffMode not preserved from minified JSON")
}
if settings["enableHooks"] != true {
t.Error("enableHooks not preserved from minified JSON")
}
}
func TestDroidEdit_CreatesDirectoryIfMissing(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
// Directory doesn't exist
if _, err := os.Stat(settingsDir); !os.IsNotExist(err) {
t.Fatal("directory should not exist before test")
}
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
// Directory should be created
if _, err := os.Stat(settingsDir); os.IsNotExist(err) {
t.Fatal("directory was not created")
}
// File should exist and be valid
settingsPath := filepath.Join(settingsDir, "settings.json")
data, err := os.ReadFile(settingsPath)
if err != nil {
t.Fatal("settings file not created")
}
var settings map[string]any
if err := json.Unmarshal(data, &settings); err != nil {
t.Fatal("created file is not valid JSON")
}
}
func TestDroidEdit_PreservesFileAfterError(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Valid original content
original := `{"diffMode": "github", "customModels": [], "sessionDefaultSettings": {}}`
os.WriteFile(settingsPath, []byte(original), 0o644)
// Empty models list is a no-op, should not modify file
d.Edit([]string{})
data, _ := os.ReadFile(settingsPath)
if string(data) != original {
t.Error("file was modified when it should not have been")
}
}
func TestDroidEdit_BackupCreated(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
os.MkdirAll(settingsDir, 0o755)
// Use a unique marker to identify our backup
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
original := fmt.Sprintf(`{"diffMode": "%s", "customModels": [], "sessionDefaultSettings": {}}`, uniqueMarker)
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
// Find backup containing our unique marker
backups, _ := filepath.Glob(filepath.Join(backupDir, "settings.json.*"))
foundBackup := false
for _, backup := range backups {
data, err := os.ReadFile(backup)
if err != nil {
continue
}
if string(data) == original {
foundBackup = true
break
}
}
if !foundBackup {
t.Error("backup with original content not found")
}
// Main file should be modified
newData, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(newData, &settings)
models := settings["customModels"].([]any)
if len(models) != 1 {
t.Error("main file was not updated")
}
}
func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
os.WriteFile(settingsPath, []byte(`{"customModels": [], "sessionDefaultSettings": {}}`), 0o644)
// Add many models
var models []string
for i := range 100 {
models = append(models, fmt.Sprintf("model-%d", i))
}
if err := d.Edit(models); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
customModels := settings["customModels"].([]any)
if len(customModels) != 100 {
t.Errorf("expected 100 models, got %d", len(customModels))
}
// Verify indices are correct
for i, m := range customModels {
entry := m.(map[string]any)
if entry["index"] != float64(i) {
t.Errorf("model %d has wrong index: %v", i, entry["index"])
}
}
}
func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
d := &Droid{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
settingsDir := filepath.Join(tmpDir, ".factory")
settingsPath := filepath.Join(settingsDir, "settings.json")
os.MkdirAll(settingsDir, 0o755)
// Arrays with mixed types (valid JSON)
original := `{
"mixedArray": [1, "two", true, null, {"nested": "obj"}, [1,2,3]],
"customModels": [],
"sessionDefaultSettings": {}
}`
os.WriteFile(settingsPath, []byte(original), 0o644)
if err := d.Edit([]string{"model-a"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(settingsPath)
var settings map[string]any
json.Unmarshal(data, &settings)
arr := settings["mixedArray"].([]any)
if len(arr) != 6 {
t.Error("mixedArray length not preserved")
}
if arr[0] != float64(1) {
t.Error("number in mixed array not preserved")
}
if arr[1] != "two" {
t.Error("string in mixed array not preserved")
}
if arr[2] != true {
t.Error("bool in mixed array not preserved")
}
if arr[3] != nil {
t.Error("null in mixed array not preserved")
}
if nested, ok := arr[4].(map[string]any); !ok || nested["nested"] != "obj" {
t.Error("object in mixed array not preserved")
}
if innerArr, ok := arr[5].([]any); !ok || len(innerArr) != 3 {
t.Error("array in mixed array not preserved")
}
}

View File

@@ -34,6 +34,7 @@ type Editor interface {
// Edit updates the config files for the integration with the given models
Edit(models []string) error
// Models returns the models currently configured for the integration
// TODO(parthsareen): add error return to Models()
Models() []string
}
@@ -229,15 +230,15 @@ func runIntegration(name, modelName string) error {
return r.Run(modelName)
}
// ConfigCmd returns the cobra command for configuring integrations.
func ConfigCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var launchFlag bool
var configFlag bool
cmd := &cobra.Command{
Use: "config [INTEGRATION]",
Short: "Configure an external integration to use Ollama",
Long: `Configure an external application to use Ollama models.
Use: "launch [INTEGRATION]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported integrations:
claude Claude Code
@@ -246,9 +247,10 @@ Supported integrations:
opencode OpenCode
Examples:
ollama config
ollama config claude
ollama config droid --launch`,
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
@@ -271,8 +273,8 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
// If --launch without --model, use saved config if available
if launchFlag && modelFlag == "" {
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
}
@@ -333,29 +335,19 @@ Examples:
}
}
if slices.ContainsFunc(models, func(m string) bool {
return !strings.HasSuffix(m, "cloud")
}) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Coding agents work best with at least 64k context. Either:")
fmt.Fprintln(os.Stderr, " - Set the context slider in Ollama app settings")
fmt.Fprintln(os.Stderr, " - Run: OLLAMA_CONTEXT_LENGTH=64000 ollama serve")
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
if launchFlag {
return runIntegration(name, models[0])
}
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
}
fmt.Fprintf(os.Stderr, "Run 'ollama config %s --launch' to start with %s\n", strings.ToLower(name), models[0])
return nil
return runIntegration(name, models[0])
},
}
cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
cmd.Flags().BoolVar(&launchFlag, "launch", false, "Launch the integration after configuring")
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
return cmd
}

View File

@@ -81,17 +81,17 @@ func TestHasLocalModel(t *testing.T) {
}
}
func TestConfigCmd(t *testing.T) {
func TestLaunchCmd(t *testing.T) {
// Mock checkServerHeartbeat that always succeeds
mockCheck := func(cmd *cobra.Command, args []string) error {
return nil
}
cmd := ConfigCmd(mockCheck)
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "config [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "config [INTEGRATION]")
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -107,9 +107,9 @@ func TestConfigCmd(t *testing.T) {
t.Error("--model flag should exist")
}
launchFlag := cmd.Flags().Lookup("launch")
if launchFlag == nil {
t.Error("--launch flag should exist")
configFlag := cmd.Flags().Lookup("config")
if configFlag == nil {
t.Error("--config flag should exist")
}
})
@@ -158,11 +158,11 @@ func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
}
}
func TestConfigCmd_NilHeartbeat(t *testing.T) {
func TestLaunchCmd_NilHeartbeat(t *testing.T) {
// This should not panic - cmd creation should work even with nil
cmd := ConfigCmd(nil)
cmd := LaunchCmd(nil)
if cmd == nil {
t.Fatal("ConfigCmd returned nil")
t.Fatal("LaunchCmd returned nil")
}
// PreRunE should be nil when passed nil

View File

@@ -105,17 +105,26 @@ func (o *OpenCode) Edit(modelList []string) error {
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if displayName, ok := cfgMap["name"].(string); ok {
if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
delete(models, name)
}
if isOllamaModel(cfgMap) && !selectedSet[name] {
delete(models, name)
}
}
}
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
if isOllamaModel(existing) {
existing["_launch"] = true
if name, ok := existing["name"].(string); ok {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
continue
}
models[model] = map[string]any{
"name": fmt.Sprintf("%s [Ollama]", model),
"name": model,
"_launch": true,
}
}
@@ -201,3 +210,15 @@ func (o *OpenCode) Models() []string {
slices.Sort(keys)
return keys
}
// isOllamaModel reports whether a model config entry is managed by us
func isOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
}
// previously used [Ollama] as a suffix for the model managed by ollama launch
if name, ok := cfg["name"].(string); ok {
return strings.HasSuffix(name, "[Ollama]")
}
return false
}

View File

@@ -161,6 +161,76 @@ func TestOpenCodeEdit(t *testing.T) {
assertOpenCodeModelNotExists(t, configPath, "mistral")
})
t.Run("preserve user customizations on managed models", func(t *testing.T) {
cleanup()
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
// Add custom fields to the model entry (simulating user edits)
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
entry["_myPref"] = "custom-value"
entry["_myNum"] = 42
configData, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configPath, configData, 0o644)
// Re-run Edit — should preserve custom fields
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ = os.ReadFile(configPath)
json.Unmarshal(data, &cfg)
provider = cfg["provider"].(map[string]any)
ollama = provider["ollama"].(map[string]any)
models = ollama["models"].(map[string]any)
entry = models["llama3.2"].(map[string]any)
if entry["_myPref"] != "custom-value" {
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
}
if entry["_myNum"] != float64(42) {
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
}
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
}
})
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
cleanup()
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
os.MkdirAll(configDir, 0o755)
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
if err := o.Edit([]string{"llama3.2"}); err != nil {
t.Fatal(err)
}
data, _ := os.ReadFile(configPath)
var cfg map[string]any
json.Unmarshal(data, &cfg)
provider := cfg["provider"].(map[string]any)
ollama := provider["ollama"].(map[string]any)
models := ollama["models"].(map[string]any)
entry := models["llama3.2"].(map[string]any)
// _launch marker should be added
if v, ok := entry["_launch"].(bool); !ok || !v {
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
}
// [Ollama] suffix should be stripped
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
t.Errorf("name suffix not stripped: got %q", entry["name"])
}
})
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
cleanup()
os.MkdirAll(configDir, 0o755)

View File

@@ -465,7 +465,7 @@ func confirmPrompt(prompt string) (bool, error) {
}
defer term.Restore(fd, oldState)
fmt.Fprintf(os.Stderr, "%s [y/n] ", prompt)
fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
buf := make([]byte, 1)
for {

View File

@@ -6,6 +6,10 @@ import (
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
kv["tokenizer.ggml.pre"] = "glm4"
return kv
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
}
}
// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
var err error
// Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
if kvFirst {
tt, err = tensor.Transpose(tt, 1, 0)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
// Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
return nil, err
}
if extractK {
// Slice K: [n_head, qk_nope, kv_lora_rank]
tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
// Transpose to [n_head, kv_lora_rank, qk_nope]
tt, err = tensor.Transpose(tt, 0, 2, 1)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
} else {
// Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
}
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
}
}
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
// Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
qkNope := int(p.QKNopeHeadDim)
vHeadDim := int(p.VHeadDim)
kvLoraRank := int(p.KVLoraRank)
kvPerHead := qkNope + vHeadDim
numHeads := int(p.NumAttentionHeads)
kvFirst := true
if len(t.Shape()) == 2 {
switch {
case int(t.Shape()[0]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
numHeads = int(t.Shape()[1]) / kvPerHead
}
kvFirst = true
case int(t.Shape()[1]) == kvLoraRank:
if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
numHeads = int(t.Shape()[0]) / kvPerHead
}
kvFirst = false
default:
slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
}
}
kTensor := t.Clone()
kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
WriterTo: kTensor,
})
vTensor := t.Clone()
vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
out = append(out, &ggml.Tensor{
Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
Kind: t.Kind(),
Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
WriterTo: vTensor,
})
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),

View File

@@ -4,16 +4,6 @@ title: Anthropic compatibility
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
## Recommended models
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
Pull a model before use:
```shell
ollama pull qwen3-coder
ollama pull glm-4.7:cloud
```
## Usage
### Environment variables
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
## Using with Claude Code
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
### Recommended models
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
Download a model before use:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ollama pull qwen3-coder
```
> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
```shell
ollama pull glm-4.7:cloud
```
### Quick setup
```shell
ollama launch claude
```
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Set the environment variables and run Claude Code:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:
```shell
# Local models
claude --model qwen3-coder
claude --model gpt-oss:20b
# Cloud models
claude --model glm-4.7:cloud
claude --model minimax-m2.1:cloud
```
## Endpoints

View File

@@ -8,6 +8,47 @@ title: CLI Reference
ollama run gemma3
```
### Launch integrations
```
ollama launch
```
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
#### Supported integrations
- **OpenCode** - Open-source coding assistant
- **Claude Code** - Anthropic's agentic coding tool
- **Codex** - OpenAI's coding assistant
- **Droid** - Factory's AI coding agent
#### Examples
Launch an integration interactively:
```
ollama launch
```
Launch a specific integration:
```
ollama launch claude
```
Launch with a specific model:
```
ollama launch claude --model qwen3-coder
```
Configure without launching:
```
ollama launch droid --config
```
#### Multiline input
For multiline input, you can wrap text with `"""`:

View File

@@ -3,8 +3,6 @@ title: Cloud
sidebarTitle: Cloud
---
<Info>Ollama's cloud is currently in preview.</Info>
## Cloud Models
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.

View File

@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
The default context length in Ollama is 4096 tokens.
</Note>
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
## Setting context length
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
### CLI
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
```
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
```
### Check allocated context length and model offloading

View File

@@ -102,18 +102,19 @@
"group": "Integrations",
"pages": [
"/integrations/claude-code",
"/integrations/vscode",
"/integrations/jetbrains",
"/integrations/codex",
"/integrations/cline",
"/integrations/codex",
"/integrations/droid",
"/integrations/goose",
"/integrations/zed",
"/integrations/roo-code",
"/integrations/jetbrains",
"/integrations/marimo",
"/integrations/n8n",
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
"/integrations/opencode",
"/integrations/roo-code",
"/integrations/vscode",
"/integrations/xcode",
"/integrations/zed"
]
},
{

View File

@@ -9,7 +9,7 @@ sidebarTitle: Welcome
<CardGroup cols={2}>
<Card title="Quickstart" icon="rocket" href="/quickstart">
Get up and running with your first model
Get up and running with your first model or integrate Ollama with your favorite tools
</Card>
<Card
title="Download Ollama"

View File

@@ -4,7 +4,7 @@ title: Claude Code
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
## Usage with Ollama
### Quick setup
```shell
ollama launch claude
```
To configure without launching:
```shell
ollama launch claude --config
```
### Manual setup
Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_API_KEY=""
export ANTHROPIC_BASE_URL=http://localhost:11434
```
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
2. Set the environment variables:
```shell
export ANTHROPIC_BASE_URL=https://ollama.com
export ANTHROPIC_API_KEY=<your-api-key>
```
3. Run Claude Code with a cloud model:
```shell
claude --model glm-4.7:cloud
```
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Recommended Models
### Cloud models
- `glm-4.7:cloud` - High-performance cloud model
- `minimax-m2.1:cloud` - Fast cloud model
- `qwen3-coder:480b` - Large coding model
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -13,7 +13,21 @@ npm install -g @openai/codex
## Usage with Ollama
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
### Quick setup
```
ollama launch codex
```
To configure without launching:
```shell
ollama launch codex --config
```
### Manual setup
To use `codex` with Ollama, use the `--oss` flag:

View File

@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch droid
```
To configure without launching:
```shell
ollama launch droid --config
```
### Manual setup
Add a local configuration block to `~/.factory/config.json`:
```json
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
}
```
Run `droid` in a new terminal to load the new settings.
Run `droid` in a new terminal to load the new settings.

View File

@@ -0,0 +1,106 @@
---
title: OpenCode
---
OpenCode is an open-source AI coding assistant that runs in your terminal.
## Install
Install the [OpenCode CLI](https://opencode.ai):
```bash
curl -fsSL https://opencode.ai/install.sh | bash
```
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch opencode
```
To configure without launching:
```shell
ollama launch opencode --config
```
### Manual setup
Add a configuration block to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder"
}
}
}
}
}
```
## Cloud Models
`glm-4.7:cloud` is the recommended model for use with OpenCode.
Add the cloud configuration to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama Cloud",
"options": {
"baseURL": "https://ollama.com/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
Run `opencode` in a new terminal to load the new settings.

View File

@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
<Tab title="CLI">
Open a terminal and run the command:
```
```sh
ollama run gemma3
```
</Tab>
<Tab title="cURL">
```
```sh
ollama pull gemma3
```
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
<Tab title="Python">
Start by downloading a model:
```
```sh
ollama pull gemma3
```
Then install Ollama's Python library:
```
```sh
pip install ollama
```
@@ -101,3 +101,42 @@ This quickstart will walk your through running your first model with Ollama. To
</Tabs>
See a full list of available models [here](https://ollama.com/models).
## Coding
For coding use cases, we recommend using the `glm-4.7-flash` model.
Note: this model requires 23 GB of VRAM with 64000 tokens context length.
```sh
ollama pull glm-4.7-flash
```
Alternatively, you can use a more powerful cloud model (with full context length):
```sh
ollama pull glm-4.7:cloud
```
Use `ollama launch` to quickly set up a coding tool with Ollama models:
```sh
ollama launch
```
### Supported integrations
- [OpenCode](/integrations/opencode) - Open-source coding assistant
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
- [Codex](/integrations/codex) - OpenAI's coding assistant
- [Droid](/integrations/droid) - Factory's AI coding agent
### Launch with a specific model
```sh
ollama launch claude --model glm-4.7-flash
```
### Configure without launching
```sh
ollama launch claude --config
```

View File

@@ -0,0 +1,309 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: nobody <>
Date: Sat, 24 Jan 2026 02:31:01 +0000
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).
Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups
CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
---
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
10 files changed, 64 insertions(+), 19 deletions(-)
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 7bd1044c1..3dea2205e 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67..371be7442 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
}
if constexpr (DV <= 256) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 015540666..1693479cb 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
- GGML_ASSERT(gqa_ratio % 16 == 0);
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ GGML_ASSERT(gqa_ratio % 4 == 0);
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 2074e954a..517993cb0 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 24c64cf00..97b19c67a 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 1ada657f1..989626dfa 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 86d4ffae2..173de7aac 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index f24270bb1..7b5ee968c 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
- op->src[0]->ne[0] != 256) {
- return false;
- }
- if (op->src[0]->ne[0] == 576) {
- // DeepSeek sizes
- // TODO: disabled for now, until optmized
+ op->src[0]->ne[0] != 256 &&
+ op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f..80864f303 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- int32_t nsg = 4;
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c98d269d1..d33c16079 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS

View File

@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major => swap A and B.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
return;
}
#endif // VOLTA_MMA_AVAILABLE
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

View File

@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
return;
}
}
if constexpr (DV <= 256) {

View File

@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
// For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
}
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;

View File

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

View File

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);

View File

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);

View File

@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

View File

@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {

View File

@@ -8967,6 +8967,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS

View File

@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);

View File

@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS

View File

@@ -39,6 +39,13 @@ type Model interface {
Config() config
}
// Validator is an optional interface that models can implement to perform
// validation after tensors have been loaded. If validation fails, model
// loading will fail with the returned error.
type Validator interface {
Validate() error
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
@@ -116,6 +123,13 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
if validator, ok := m.(Validator); ok {
if err := validator.Validate(); err != nil {
return nil, err
}
}
return m, nil
}

View File

@@ -1,6 +1,7 @@
package glm4moelite
import (
"errors"
"math"
"github.com/ollama/ollama/fs"
@@ -11,6 +12,8 @@ import (
"github.com/ollama/ollama/model/input"
)
var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
type Options struct {
numExpertsUsed int
numExperts int
@@ -47,7 +50,9 @@ type Attention struct {
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
KB *nn.Linear `gguf:"attn_k_b"`
VB *nn.Linear `gguf:"attn_v_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
@@ -78,15 +83,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
// MLA absorption: absorb K projection into query
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
key := kRot.Concat(ctx, kPass, 0)
attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
@@ -217,7 +223,6 @@ func New(c fs.Config) (model.Model, error) {
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
@@ -236,7 +241,7 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
@@ -279,6 +284,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
func (m *Model) Validate() error {
for _, layer := range m.Layers {
if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
return ErrOldModelFormat
}
}
return nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

View File

@@ -0,0 +1,73 @@
package glm4moelite
import (
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidate(t *testing.T) {
tests := []struct {
name string
model *Model
wantErr bool
}{
{
name: "valid model with KB and VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
},
},
wantErr: false,
},
{
name: "missing KB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{VB: &nn.Linear{}}},
},
},
wantErr: true,
},
{
name: "missing VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{KB: &nn.Linear{}}},
},
},
wantErr: true,
},
{
name: "missing both KB and VB",
model: &Model{
Layers: []Layer{
{Attention: &Attention{}},
},
},
wantErr: true,
},
{
name: "nil Attention is ok",
model: &Model{
Layers: []Layer{
{Attention: nil},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.model.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && err != ErrOldModelFormat {
t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
}
})
}
}

View File

@@ -14,8 +14,8 @@
VOL_NAME=${VOL_NAME:-"Ollama"}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CFLAGS="-mmacosx-version-min=14.0"
export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
set -e

View File

@@ -56,6 +56,12 @@ function checkEnv {
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
if (-not $env:CGO_CFLAGS) {
$env:CGO_CFLAGS = "-O3"
}
if (-not $env:CGO_CXXFLAGS) {
$env:CGO_CXXFLAGS = "-O3"
}
Write-Output "Checking version"
if (!$env:VERSION) {
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)

View File

@@ -95,6 +95,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
} else if strings.Contains(name, "attn_k_b.weight") ||
strings.Contains(name, "attn_v_b.weight") ||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
strings.Contains(name, "attn_q_a.weight") ||
strings.Contains(name, "attn_q_b.weight") {
// MLA tensors need higher precision to avoid quality degradation
newType = fsggml.TensorTypeQ8_0
} else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown
n_layer := qs.nFfnDown

View File

@@ -2508,8 +2508,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
// Check streaming preference
isStreaming := req.Stream == nil || *req.Stream
contentType := "application/x-ndjson"
if !isStreaming {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
// Get seed from options if provided
var seed int64
@@ -2530,6 +2536,8 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
var streamStarted bool
var finalResponse api.GenerateResponse
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
@@ -2560,6 +2568,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if !isStreaming {
finalResponse = res
return
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
@@ -2569,5 +2582,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if !isStreaming {
c.JSON(http.StatusOK, finalResponse)
}
}

View File

@@ -19,7 +19,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
func (mockRunner) Ping(_ context.Context) error { return nil }
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
@@ -2347,3 +2351,92 @@ func TestGenerateWithImages(t *testing.T) {
}
})
}
// TestImageGenerateStreamFalse tests that image generation respects stream=false
// and returns a single JSON response instead of streaming ndjson.
func TestImageGenerateStreamFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
mock := mockRunner{}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
t.Fatal(err)
}
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
t.Fatal(err)
}
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",
Prompt: "test prompt",
Stream: &streamFalse,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
}
body := w.Body.String()
lines := strings.Split(strings.TrimSpace(body), "\n")
if len(lines) != 1 {
t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
}
var resp api.GenerateResponse
if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp.Image != "base64image" {
t.Errorf("expected image 'base64image', got %q", resp.Image)
}
if !resp.Done {
t.Errorf("expected done=true")
}
}

View File

@@ -11,6 +11,8 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
@@ -209,10 +211,23 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
return fmt.Errorf("invalid model name: %s", modelName)
}
// TODO: find a better way to detect image input support
// For now, hardcode Flux2KleinPipeline as supporting vision (image input)
caps := capabilities
modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
if data, err := os.ReadFile(modelIndex); err == nil {
var cfg struct {
ClassName string `json:"_class_name"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
caps = append(caps, "vision")
}
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Capabilities: caps,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)

View File

@@ -532,8 +532,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
var imgs []api.ImageData
for _, fp := range filePaths {
// Normalize escaped spaces
// Normalize shell escapes
nfp := strings.ReplaceAll(fp, "\\ ", " ")
nfp = strings.ReplaceAll(nfp, "\\(", "(")
nfp = strings.ReplaceAll(nfp, "\\)", ")")
nfp = strings.ReplaceAll(nfp, "%20", " ")
data, err := getImageData(nfp)

View File

@@ -7,6 +7,8 @@ import (
"encoding/base64"
"fmt"
"image"
"image/color"
"image/draw"
_ "image/jpeg"
"image/png"
"os"
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
}
// DecodeImage decodes image bytes with EXIF orientation applied.
// Transparent images are composited onto a white background.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
return nil, err
}
img = flattenAlpha(img)
return applyOrientation(img, orientation), nil
}
// flattenAlpha composites an image onto a white background,
// removing any transparency. This is needed because image
// generation models don't handle alpha channels well.
func flattenAlpha(img image.Image) image.Image {
if _, ok := img.(*image.RGBA); !ok {
if _, ok := img.(*image.NRGBA); !ok {
// No alpha channel, return as-is
return img
}
}
bounds := img.Bounds()
dst := image.NewRGBA(bounds)
// Fill with white background
draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
// Composite the image on top
draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
return dst
}
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {