Compare commits

...

21 Commits

Author SHA1 Message Date
Parth Sareen
06edabdde1 cmd/config: install web search plugin to user-level extensions dir (#14362) 2026-02-22 02:17:03 -08:00
Jeffrey Morgan
8b4e5a82a8 mlx: remove noisy error output from dynamic library loading (#14346)
The recent change in #14322 added tryLoadByName() which attempts to
load libmlxc.dylib via rpath before searching directories. This is an
optimization for Homebrew installations where rpath is correctly set.

However, when rpath isn't set (which is the common case for app bundle
installations), dlopen fails and the CHECK macro prints an error to
stderr:

  ERROR - dynamic.c:21 - CHECK failed: handle->ctx != NULL

This error is misleading because it's an expected failure path - the
code correctly falls back to searching the executable directory and
loads the library successfully. The error message causes user confusion
and makes it appear that something is broken.

Replace the CHECK macro with a simple return code so the C code fails
silently. The Go code already handles error logging appropriately:
tryLoadByName() fails silently (intentional fallback), while
tryLoadFromDir() logs via slog.Error() when explicit path loading fails.
2026-02-20 23:46:07 -08:00
Parth Sareen
3445223311 cmd: openclaw onboarding (#14344) 2026-02-20 19:08:38 -08:00
Jeffrey Morgan
fa6c0127e6 app: expose server's default context length to UI (#14037)
Parse the default_num_ctx from the server's "vram-based default context"
log line and expose it through the inference compute API. This eliminates
duplicate VRAM tier calculation logic in the frontend.

- Add InferenceInfo struct with Computes and DefaultContextLength
- Rename GetInferenceComputer to GetInferenceInfo
- Handle missing default context line gracefully (older servers)
- Add DefaultContextLength to InferenceComputeResponse
- Update Settings UI to use server's default, disable slider while loading
- Add disabled prop to Slider component (grays out + hides handle)
- Migrate existing users with context_length=4096 to 0 (auto mode)
2026-02-20 18:56:30 -08:00
Patrick Devine
97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package which includes:
  * New BPE and SentencePiece tokenizers
  * Removing the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00
natl-set
458dd1b9d9 mlx: try loading library via rpath before searching directories (#14322)
The existing code manually searches directories for libmlxc.* and passes
full paths to dlopen, bypassing the binary's rpath. This means MLX
libraries installed via package managers (e.g., Homebrew) aren't found
even when rpath is correctly set at link time.

This change adds a fallback that tries loading via rpath first (using
just the library name), before falling back to the existing directory
search. This follows standard Unix/macOS conventions and works with any
installation that sets rpath.

Fixes library loading on macOS with Homebrew-installed mlx-c without
requiring OLLAMA_LIBRARY_PATH environment variable.

Co-authored-by: Natl <nat@MacBook-Pro.local>
2026-02-19 10:55:02 -08:00
Bruce MacDonald
9d02d1d767 install: prevent partial download script execution (#14311)
Wrap script in main function so that a truncated partial download doesn't end up executing half a script.
2026-02-18 18:32:45 -08:00
Bruce MacDonald
1a636fb47a cmd: set codex env vars on launch and handle zstd request bodies (#14122)
The Codex runner was not setting OPENAI_BASE_URL or OPENAI_API_KEY, this prevents Codex from sending requests to api.openai.com instead of the local Ollama server. This mirrors the approach used by the Claude runner.

Codex v0.98.0 sends zstd-compressed request bodies to the /v1/responses endpoint. Add decompression support in ResponsesMiddleware with an 8MB max decompressed size limit to prevent resource exhaustion.
2026-02-18 17:19:36 -08:00
Patrick Devine
0759fface9 Revert "chore: update mlx-c bindings to 0.5.0 (#14303)" (#14316)
This reverts commit f01a9a7859.
2026-02-18 17:01:25 -08:00
Parth Sareen
325b72bc31 cmd/tui: default to single-select for editor integrations (#14302) 2026-02-17 18:17:27 -08:00
Patrick Devine
f01a9a7859 chore: update mlx-c bindings to 0.5.0 (#14303) 2026-02-17 16:48:16 -08:00
Patrick Devine
9aefd2dfee model: add qwen3 support to mlxrunner (#14293) 2026-02-17 13:58:49 -08:00
Patrick Devine
d07e4a1dd3 bugfix: better mlx model scheduling (#14290)
This fixes a bug with current MLX based models which don't get loaded/unloaded correctly. The first model currently gets loaded and then subsequent model starts get shunted to the first runner which results in the wrong model being run.
2026-02-17 13:57:05 -08:00
Parth Sareen
8a257ec00a docs: make integrations more discoverable (#14301)
* docs: add Pi integration page

* docs: flatten integration sidebar with expanded subheadings

* docs: add OpenClaw and Claude Code to quickstart
2026-02-17 13:27:25 -08:00
Parth Sareen
2f4de1acf7 cmd: ollama launch always show model picker (#14299) 2026-02-17 12:02:14 -08:00
Parth Sareen
ec95c45f70 cmd/config: ollama launch cline CLI (#14294) 2026-02-17 11:37:53 -08:00
Patrick Devine
3a88f7eb20 bugfix: add missing linear layer factory (#14289) 2026-02-16 17:22:20 -08:00
Patrick Devine
0d5da826d4 bugfix: display the parameter count correctly in mlx for ollama show (#14285) 2026-02-16 13:03:34 -08:00
Patrick Devine
9b795698b8 model: add llama3 architecture to mlxrunner (#14277) 2026-02-15 23:06:28 -08:00
Patrick Devine
041fb77639 model: add gemma3 to the mlxrunner (#14276)
This change adds the gemma3 model to the mlxrunner and simplifies some of the quantization
code for loading weights.
2026-02-15 22:47:59 -08:00
Saumil Shah
8224cce583 readme: update download link for macOS (#1) (#14271) 2026-02-15 15:25:15 -08:00
72 changed files with 7078 additions and 420 deletions

View File

@@ -16,7 +16,7 @@ Start building with open models.
curl -fsSL https://ollama.com/install.sh | sh
```
or [download manually](http://localhost:8080/download/Ollama.dmg)
or [download manually](https://ollama.com/download/Ollama.dmg)
### Windows

View File

@@ -41,6 +41,11 @@ type InferenceCompute struct {
VRAM string
}
type InferenceInfo struct {
Computes []InferenceCompute
DefaultContextLength int
}
func New(s *store.Store, devMode bool) *Server {
p := resolvePath("ollama")
return &Server{store: s, bin: p, dev: devMode}
@@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) {
// Attempt to retrieve inference compute information from the server
// log. Set ctx to timeout to control how long to wait for the logs to appear
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
inference := []InferenceCompute{}
marker := regexp.MustCompile(`inference compute.*library=`)
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
info := &InferenceInfo{}
computeMarker := regexp.MustCompile(`inference compute.*library=`)
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
q := `inference compute.*%s=["]([^"]*)["]`
nq := `inference compute.*%s=(\S+)\s`
type regex struct {
@@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
match := marker.FindStringSubmatch(line)
if len(match) > 0 {
// Check for inference compute lines
if computeMarker.MatchString(line) {
ic := InferenceCompute{
Library: get("library", line),
Variant: get("variant", line),
@@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
}
slog.Info("Matched", "inference compute", ic)
inference = append(inference, ic)
} else {
// Break out on first non matching line after we start matching
if len(inference) > 0 {
return inference, nil
info.Computes = append(info.Computes, ic)
continue
}
// Check for default context length line
if defaultCtxMarker.MatchString(line) {
match := defaultCtxRegex.FindStringSubmatch(line)
if len(match) > 1 {
numCtx, err := strconv.Atoi(match[1])
if err == nil {
info.DefaultContextLength = numCtx
slog.Info("Matched default context length", "default_num_ctx", numCtx)
}
}
return info, nil
}
// If we've found compute info but hit a non-matching line, return what we have
// This handles older server versions that don't log the default context line
if len(info.Computes) > 0 {
return info, nil
}
}
time.Sleep(100 * time.Millisecond)

View File

@@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
}
}
func TestGetInferenceComputer(t *testing.T) {
func TestGetInferenceInfo(t *testing.T) {
tests := []struct {
name string
log string
exp []InferenceCompute
name string
log string
expComputes []InferenceCompute
expDefaultCtxLen int
}{
{
name: "metal",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 262144,
},
{
name: "cpu",
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cpu",
Driver: "0.0",
VRAM: "31.3 GiB",
}},
expDefaultCtxLen: 32768,
},
{
name: "cuda1",
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
releasing cuda driver library
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{{
expComputes: []InferenceCompute{{
Library: "cuda",
Variant: "v12",
Compute: "6.1",
@@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
Name: "NVIDIA GeForce GT 1030",
VRAM: "3.9 GiB",
}},
expDefaultCtxLen: 4096,
},
{
name: "frank",
@@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
releasing cuda driver library
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
exp: []InferenceCompute{
expComputes: []InferenceCompute{
{
Library: "cuda",
Variant: "v12",
@@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
VRAM: "16.0 GiB",
},
},
expDefaultCtxLen: 32768,
},
{
name: "missing_default_context",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 0, // No default context line, should return 0
},
}
for _, tt := range tests {
@@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
}
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
ics, err := GetInferenceComputer(ctx)
info, err := GetInferenceInfo(ctx)
if err != nil {
t.Fatalf(" failed to get inference compute: %v", err)
t.Fatalf("failed to get inference info: %v", err)
}
if !reflect.DeepEqual(ics, tt.exp) {
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
}
if info.DefaultContextLength != tt.expDefaultCtxLen {
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
}
})
}
}
func TestGetInferenceComputerTimeout(t *testing.T) {
func TestGetInferenceInfoTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
tmpDir := t.TempDir()
@@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) {
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}
_, err = GetInferenceComputer(ctx)
_, err = GetInferenceInfo(ctx)
if err == nil {
t.Fatal("expected timeout")
}

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 13
const currentSchemaVersion = 14
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -73,7 +73,7 @@ func (db *database) init() error {
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -251,6 +251,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
case 13:
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
if err := db.migrateV13ToV14(); err != nil {
return fmt.Errorf("migrate v13 to v14: %w", err)
}
version = 14
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -474,6 +480,22 @@ func (db *database) migrateV12ToV13() error {
return nil
}
// migrateV13ToV14 changes the default context_length from 4096 to 0.
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
func (db *database) migrateV13ToV14() error {
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
if err != nil {
return fmt.Errorf("update context_length default: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`

View File

@@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) {
})
}
func TestMigrationV13ToV14ContextLength(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
if err != nil {
t.Fatalf("failed to seed v13 settings row: %v", err)
}
if err := db.migrate(); err != nil {
t.Fatalf("migration from v13 to v14 failed: %v", err)
}
var contextLength int
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
t.Fatalf("failed to read context_length: %v", err)
}
if contextLength != 0 {
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
}
version, err := db.getSchemaVersion()
if err != nil {
t.Fatalf("failed to get schema version: %v", err)
}
if version != currentSchemaVersion {
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
}
}
func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir()

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 4096,
context_length INTEGER NOT NULL DEFAULT 0,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,

View File

@@ -289,10 +289,12 @@ export class InferenceCompute {
}
export class InferenceComputeResponse {
inferenceComputes: InferenceCompute[];
defaultContextLength: number;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
this.defaultContextLength = source["defaultContextLength"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

View File

@@ -4,7 +4,6 @@ import {
ChatEvent,
DownloadEvent,
ErrorEvent,
InferenceCompute,
InferenceComputeResponse,
ModelCapabilitiesResponse,
Model,
@@ -407,7 +406,7 @@ export async function* pullModel(
}
}
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
if (!response.ok) {
throw new Error(
@@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
}
const data = await response.json();
const inferenceComputeResponse = new InferenceComputeResponse(data);
return inferenceComputeResponse.inferenceComputes || [];
return new InferenceComputeResponse(data);
}
export async function fetchHealth(): Promise<boolean> {

View File

@@ -26,6 +26,7 @@ import {
type CloudStatusResponse,
updateCloudSetting,
updateSettings,
getInferenceCompute,
} from "@/api";
function AnimatedDots() {
@@ -77,6 +78,13 @@ export default function Settings() {
const settings = settingsData?.settings || null;
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
});
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
const updateSettingsMutation = useMutation({
mutationFn: updateSettings,
onSuccess: () => {
@@ -204,7 +212,7 @@ export default function Settings() {
Models: "",
Agent: false,
Tools: false,
ContextLength: 4096,
ContextLength: 0,
});
updateSettingsMutation.mutate(defaultSettings);
}
@@ -507,13 +515,11 @@ export default function Settings() {
</Description>
<div className="mt-3">
<Slider
value={(() => {
// Otherwise use the settings value
return settings.ContextLength || 4096;
})()}
value={settings.ContextLength || defaultContextLength || 0}
onChange={(value) => {
handleChange("ContextLength", value);
}}
disabled={!defaultContextLength}
options={[
{ value: 4096, label: "4k" },
{ value: 8192, label: "8k" },

View File

@@ -6,10 +6,11 @@ export interface SliderProps {
value?: number;
onChange?: (value: number) => void;
className?: string;
disabled?: boolean;
}
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
({ label, options, value = 0, onChange }, ref) => {
({ label, options, value = 0, onChange, disabled = false }, ref) => {
const [selectedValue, setSelectedValue] = React.useState(value);
const [isDragging, setIsDragging] = React.useState(false);
const containerRef = React.useRef<HTMLDivElement>(null);
@@ -20,6 +21,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}, [value]);
const handleClick = (optionValue: number) => {
if (disabled) return;
setSelectedValue(optionValue);
onChange?.(optionValue);
};
@@ -39,6 +41,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
};
const handleMouseDown = (e: React.MouseEvent) => {
if (disabled) return;
setIsDragging(true);
e.preventDefault();
};
@@ -77,7 +80,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}
return (
<div className="space-y-2" ref={ref}>
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
{label && <label className="text-sm font-medium">{label}</label>}
<div className="relative">
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
@@ -88,10 +91,11 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
<button
onClick={() => handleClick(option.value)}
onMouseDown={handleMouseDown}
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
disabled={disabled}
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
>
<div className="relative w-5 h-5 flex items-center justify-center">
{selectedValue === option.value && (
{selectedValue === option.value && !disabled && (
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
)}
</div>

View File

@@ -28,12 +28,14 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
currentChatId && currentChatId !== "new" ? currentChatId : "",
);
const { data: inferenceComputes = [] } = useQuery({
queryKey: ["inference-compute"],
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
enabled: !settings.selectedModel, // Only fetch if no model is selected
});
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
const totalVRAM = useMemo(
() => getTotalVRAM(inferenceComputes),
[inferenceComputes],

View File

@@ -45,7 +45,8 @@ type InferenceCompute struct {
}
type InferenceComputeResponse struct {
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
DefaultContextLength int `json:"defaultContextLength"`
}
type ModelCapabilitiesResponse struct {

View File

@@ -1420,11 +1420,6 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
settings.Models = envconfig.Models()
}
// set default context length if not set
if settings.ContextLength == 0 {
settings.ContextLength = 4096
}
// Include current runtime settings
settings.Agent = s.Agent
settings.Tools = s.Tools
@@ -1500,14 +1495,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
info, err := server.GetInferenceInfo(ctx)
if err != nil {
s.log().Error("failed to get inference compute", "error", err)
return fmt.Errorf("failed to get inference compute: %w", err)
s.log().Error("failed to get inference info", "error", err)
return fmt.Errorf("failed to get inference info: %w", err)
}
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
for i, ic := range serverInferenceComputes {
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
for i, ic := range info.Computes {
inferenceComputes[i] = responses.InferenceCompute{
Library: ic.Library,
Variant: ic.Variant,
@@ -1519,7 +1514,8 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
}
response := responses.InferenceComputeResponse{
InferenceComputes: inferenceComputes,
InferenceComputes: inferenceComputes,
DefaultContextLength: info.DefaultContextLength,
}
w.Header().Set("Content-Type", "application/json")

View File

@@ -57,9 +57,9 @@ import (
func init() {
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems)
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
mfConfig.Renderer = cmd.Args
}
}
@@ -1897,9 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem) (string, error) {
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
result, err := tui.SelectSingle(title, tuiItems)
result, err := tui.SelectSingle(title, tuiItems, current)
if errors.Is(err, tui.ErrCancelled) {
return "", config.ErrCancelled
}
@@ -1952,6 +1956,10 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
launchIntegration := func(name string) bool {
if err := config.EnsureInstalled(name); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return true
}
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {

View File

@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
if aliases["primary"] == "" || force {
primary, err := DefaultSingleSelector("Select model:", items)
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
if err != nil {
return nil, false, err
}

123
cmd/config/cline.go Normal file
View File

@@ -0,0 +1,123 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/envconfig"
)
// Cline implements Runner and Editor for the Cline CLI integration
type Cline struct{}
func (c *Cline) String() string { return "Cline" }
func (c *Cline) Run(model string, args []string) error {
if _, err := exec.LookPath("cline"); err != nil {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *Cline) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".cline", "data", "globalState.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Cline) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
}
}
// Set Ollama as the provider for both act and plan modes
baseURL := envconfig.Host().String()
config["ollamaBaseUrl"] = baseURL
config["actModeApiProvider"] = "ollama"
config["actModeOllamaModelId"] = models[0]
config["actModeOllamaBaseUrl"] = baseURL
config["planModeApiProvider"] = "ollama"
config["planModeOllamaModelId"] = models[0]
config["planModeOllamaBaseUrl"] = baseURL
config["welcomeViewCompleted"] = true
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Cline) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}
if config["actModeApiProvider"] != "ollama" {
return nil
}
modelID, _ := config["actModeOllamaModelId"].(string)
if modelID == "" {
return nil
}
return []string{modelID}
}

204
cmd/config/cline_test.go Normal file
View File

@@ -0,0 +1,204 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestClineIntegration(t *testing.T) {
c := &Cline{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Cline" {
t.Errorf("String() = %q, want %q", got, "Cline")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClineEdit(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var config map[string]any
json.Unmarshal(data, &config)
return config
}
t.Run("creates config from scratch", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeApiProvider"] != "ollama" {
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
}
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
}
if config["planModeApiProvider"] != "ollama" {
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
}
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
}
if config["welcomeViewCompleted"] != true {
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
}
})
t.Run("preserves existing fields", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
os.MkdirAll(configDir, 0o755)
existing := map[string]any{
"remoteRulesToggles": map[string]any{},
"remoteWorkflowToggles": map[string]any{},
"customSetting": "keep-me",
}
data, _ := json.Marshal(existing)
os.WriteFile(configPath, data, 0o644)
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["customSetting"] != "keep-me" {
t.Errorf("customSetting was not preserved")
}
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
})
t.Run("updates model on re-edit", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
if config["planModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(nil); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Error("expected no config file to be created for empty models")
}
})
t.Run("uses first model as primary", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
}
})
}
func TestClineModels(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
t.Run("returns nil when no config", func(t *testing.T) {
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "anthropic",
"actModeOllamaModelId": "some-model",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns model when ollama is configured", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "ollama",
"actModeOllamaModelId": "kimi-k2.5:cloud",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
}
})
}
func TestClinePaths(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns nil when no config exists", func(t *testing.T) {
if paths := c.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".cline", "data")
os.MkdirAll(configDir, 0o755)
configPath := filepath.Join(configDir, "globalState.json")
os.WriteFile(configPath, []byte("{}"), 0o644)
paths := c.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -15,8 +15,9 @@ import (
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Onboarded bool `json:"onboarded,omitempty"`
}
type config struct {
@@ -139,34 +140,54 @@ func SaveIntegration(appName string, models []string) error {
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
if existing != nil && existing.Aliases != nil {
var onboarded bool
if existing != nil {
aliases = existing.Aliases
onboarded = existing.Onboarded
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
Models: models,
Aliases: aliases,
Onboarded: onboarded,
}
return save(cfg)
}
// integrationOnboarded marks an integration as onboarded in ollama's config.
func integrationOnboarded(appName string) error {
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
existing.Onboarded = true
cfg.Integrations[key] = existing
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return ""
}
return ic.Models[0]
return integrationConfig.Models[0]
}
// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
return nil
}
return ic.Models
return integrationConfig.Models
}
// LastModel returns the last model that was run, or empty string if none.
@@ -234,12 +255,12 @@ func loadIntegration(appName string) (*integration, error) {
return nil, err
}
ic, ok := cfg.Integrations[strings.ToLower(appName)]
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return ic, nil
return integrationConfig, nil
}
func saveAliases(appName string, aliases map[string]string) error {
@@ -272,8 +293,8 @@ func listIntegrations() ([]integration, error) {
}
result := make([]integration, 0, len(cfg.Integrations))
for _, ic := range cfg.Integrations {
result = append(result, *ic)
for _, integrationConfig := range cfg.Integrations {
result = append(result, *integrationConfig)
}
return result, nil

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"maps"
"net/http"
"os"
"os/exec"
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"cline": &Cline{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
"pi": true,
}
// integrationInstallHints maps integration names to install URLs.
var integrationInstallHints = map[string]string{
"claude": "https://code.claude.com/docs/en/quickstart",
"cline": "https://cline.bot/cli",
"openclaw": "https://docs.openclaw.ai",
"codex": "https://developers.openai.com/codex/cli/",
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
"opencode": "https://opencode.ai",
"pi": "https://github.com/badlogic/pi-mono",
}
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
// integrationDescriptions maps integration names to short descriptions.
var integrationDescriptions = map[string]string{
"claude": "Anthropic's coding tool with subagents",
"cline": "Autonomous coding agent with parallel execution",
"codex": "OpenAI's open-source coding agent",
"openclaw": "Personal AI with 100+ skills",
"droid": "Factory's coding agent across terminal and IDEs",
"opencode": "Anomaly's open-source coding agent",
"pi": "Minimal AI agent toolkit with plugin support",
}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
// integrationOrder defines a custom display order for integrations.
// Integrations listed here are placed at the end in the given order;
// all others appear first, sorted alphabetically.
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
// with integrationOrder entries placed at the end.
func ListIntegrationInfos() []IntegrationInfo {
var result []IntegrationInfo
for name, r := range integrations {
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
Description: integrationDescriptions[name],
})
}
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
}
slices.SortFunc(result, func(a, b IntegrationInfo) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
// Both have custom order: sort by their rank
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
// Only one has custom order: it goes last
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
// Neither has custom order: alphabetical
return strings.Compare(a.Name, b.Name)
})
return result
@@ -186,14 +214,45 @@ func IsIntegrationInstalled(name string) bool {
case "droid":
_, err := exec.LookPath("droid")
return err == nil
case "cline":
_, err := exec.LookPath("cline")
return err == nil
case "opencode":
_, err := exec.LookPath("opencode")
return err == nil
case "pi":
_, err := exec.LookPath("pi")
return err == nil
default:
return true // Assume installed for unknown integrations
}
}
// AutoInstallable returns true if the integration can be automatically
// installed when not found (e.g. via npm).
func AutoInstallable(name string) bool {
switch strings.ToLower(name) {
case "openclaw", "clawdbot", "moltbot":
return true
default:
return false
}
}
// EnsureInstalled checks if an auto-installable integration is present and
// offers to install it if missing. Returns nil for non-auto-installable
// integrations or when the binary is already on PATH.
func EnsureInstalled(name string) error {
if !AutoInstallable(name) {
return nil
}
if IsIntegrationInstalled(name) {
return nil
}
_, err := ensureOpenclawInstalled()
return err
}
// IsEditorIntegration returns true if the named integration uses multi-model
// selection (implements the Editor interface).
func IsEditorIntegration(name string) bool {
@@ -214,7 +273,8 @@ type ModelItem struct {
}
// SingleSelector is a function type for single item selection.
type SingleSelector func(title string, items []ModelItem) (string, error)
// current is the name of the previously selected item to highlight; empty means no pre-selection.
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
// MultiSelector is a function type for multi item selection.
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
@@ -257,7 +317,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
}
selected, err := selector("Select model to run:", items)
selected, err := selector("Select model to run:", items, "")
if err != nil {
return "", err
}
@@ -367,13 +427,11 @@ func selectIntegration() (string, error) {
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
var items []ModelItem
for _, name := range names {
for name, r := range integrations {
if integrationAliases[name] {
continue
}
r := integrations[name]
description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
@@ -381,7 +439,25 @@ func selectIntegration() (string, error) {
items = append(items, ModelItem{Name: name, Description: description})
}
return DefaultSingleSelector("Select integration:", items)
orderRank := make(map[string]int, len(integrationOrder))
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
slices.SortFunc(items, func(a, b ModelItem) int {
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
if aRank > 0 && bRank > 0 {
return aRank - bRank
}
if aRank > 0 {
return 1
}
if bRank > 0 {
return -1
}
return strings.Compare(a.Name, b.Name)
})
return DefaultSingleSelector("Select integration:", items, "")
}
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
@@ -439,7 +515,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
if _, ok := r.(AliasConfigurer); ok {
prompt = fmt.Sprintf("Select Primary model for %s:", r)
}
model, err := single(prompt, items)
model, err := single(prompt, items, current)
if err != nil {
return nil, err
}
@@ -812,10 +888,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
Supported integrations:
claude Claude Code
cline Cline
codex Codex
droid Droid
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
Examples:
ollama launch
@@ -873,6 +951,10 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if err := EnsureInstalled(name); err != nil {
return err
}
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
modelFlag = ""
}
@@ -915,11 +997,9 @@ Examples:
}
// Validate saved model still exists
cloudCleared := false
if model != "" && modelFlag == "" {
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
model = ""
cloudCleared = true
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
@@ -928,18 +1008,16 @@ Examples:
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Show picker so user can change model (skip when --model flag provided)
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
@@ -1001,27 +1079,13 @@ Examples:
return err
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
savedModels := filterDisabledCloudModels(saved.Models)
if len(savedModels) != len(saved.Models) {
_ = SaveIntegration(name, savedModels)
}
if len(savedModels) == 0 {
// All saved models were cloud — fall through to picker
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
} else {
models = savedModels
return runIntegration(name, models[0], passArgs)
}
} else {
current := ""
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
current = saved.Models[0]
}
var err error
models, err = selectModels(cmd.Context(), name, "")
models, err = selectModels(cmd.Context(), name, current)
if errors.Is(err, errCancelled) {
return nil
}

View File

@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
}
})
t.Run("sorted by name", func(t *testing.T) {
t.Run("sorted with custom order at end", func(t *testing.T) {
// integrationOrder entries (cline, opencode) should appear last, in that order.
// All other entries should be sorted alphabetically before them.
orderRank := make(map[string]int)
for i, name := range integrationOrder {
orderRank[name] = i + 1
}
for i := 1; i < len(infos); i++ {
if infos[i-1].Name >= infos[i].Name {
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
switch {
case aRank == 0 && bRank == 0:
if infos[i-1].Name >= infos[i].Name {
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
}
case aRank > 0 && bRank == 0:
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
case aRank > 0 && bRank > 0:
if aRank >= bRank {
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
}
}
}
})

View File

@@ -1,81 +1,287 @@
package config
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, args []string) error {
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
return selectModels(context.Background(), "openclaw", "")
})
if errors.Is(err, errCancelled) {
return nil
}
bin, err := ensureOpenclawInstalled()
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
firstLaunch := true
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
firstLaunch = !integrationConfig.Onboarded
}
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := confirmPrompt("I understand the risks. Continue?")
if err != nil {
return err
}
if !ok {
return nil
}
}
if !c.onboarded() {
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
cmd := exec.Command(bin, "onboard",
"--non-interactive",
"--accept-risk",
"--auth-choice", "skip",
"--gateway-token", "ollama",
"--install-daemon",
"--skip-channels",
"--skip-skills",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
}
patchDeviceScopes()
// Onboarding overwrites openclaw.json, so re-apply the model config
// that Edit() wrote before Run() was called.
if err := c.Edit([]string{model}); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
}
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
if ensureWebSearchPlugin() {
registerWebSearchPlugin()
}
}
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
}
err = cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
// When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow.
if len(args) > 0 {
cmd := exec.Command(bin, args...)
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil
}
return err
token, port := c.gatewayInfo()
addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes from Edit() above (model, provider, etc.).
if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv()
if err := restart.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
}
if !waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
}
}
// If the gateway isn't running, start it as a background child process.
if !portOpen(addr) {
gw := exec.Command(bin, "gateway", "run", "--force")
gw.Env = openclawEnv()
if err := gw.Start(); err != nil {
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
}
defer func() {
if gw.Process != nil {
_ = gw.Process.Kill()
_ = gw.Wait()
}
}()
}
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
if !waitForPort(addr, 30*time.Second) {
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
}
printOpenclawReady(bin, token, port, firstLaunch)
tuiArgs := []string{"tui"}
if firstLaunch {
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
}
tui := exec.Command(bin, tuiArgs...)
tui.Env = openclawEnv()
tui.Stdin = os.Stdin
tui.Stdout = os.Stdout
tui.Stderr = os.Stderr
if err := tui.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort
home, err := os.UserHomeDir()
if err != nil {
return "", port
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
continue
}
gw, _ := config["gateway"].(map[string]any)
if p, ok := gw["port"].(float64); ok && p > 0 {
port = int(p)
}
auth, _ := gw["auth"].(map[string]any)
if t, _ := auth["token"].(string); t != "" {
token = t
}
return token, port
}
return "", port
}
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
u := fmt.Sprintf("http://localhost:%d", port)
if token != "" {
u += "/#token=" + url.QueryEscape(token)
}
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
}
}
// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
clear := map[string]bool{
"ANTHROPIC_API_KEY": true,
"ANTHROPIC_OAUTH_TOKEN": true,
"OPENAI_API_KEY": true,
"GEMINI_API_KEY": true,
"MISTRAL_API_KEY": true,
"GROQ_API_KEY": true,
"XAI_API_KEY": true,
"OPENROUTER_API_KEY": true,
}
var env []string
for _, e := range os.Environ() {
key, _, _ := strings.Cut(e, "=")
if !clear[key] {
env = append(env, e)
}
}
return env
}
// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func waitForPort(addr string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err == nil {
conn.Close()
return true
}
time.Sleep(250 * time.Millisecond)
}
return false
}
func windowsHint(err error) error {
if runtime.GOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\n"+
"OpenClaw runs best on WSL2.\n"+
"Quick setup: wsl --install\n"+
"Guide: https://docs.openclaw.ai/windows", err)
}
// onboarded checks if OpenClaw onboarding wizard was completed
@@ -107,6 +313,144 @@ func (c *Openclaw) onboarded() bool {
return lastRunAt != ""
}
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
home, err := os.UserHomeDir()
if err != nil {
return
}
deviceID := readLocalDeviceID(home)
if deviceID == "" {
return
}
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var devices map[string]map[string]any
if err := json.Unmarshal(data, &devices); err != nil {
return
}
dev, ok := devices[deviceID]
if !ok {
return
}
required := []string{
"operator.read",
"operator.admin",
"operator.approvals",
"operator.pairing",
}
changed := patchScopes(dev, "scopes", required)
if tokens, ok := dev["tokens"].(map[string]any); ok {
for _, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok {
if patchScopes(tokenMap, "scopes", required) {
changed = true
}
}
}
}
if !changed {
return
}
out, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
if err != nil {
return ""
}
var auth map[string]any
if err := json.Unmarshal(data, &auth); err != nil {
return ""
}
id, _ := auth["deviceId"].(string)
return id
}
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
existing, _ := obj[key].([]any)
have := make(map[string]bool, len(existing))
for _, s := range existing {
if str, ok := s.(string); ok {
have[str] = true
}
}
added := false
for _, s := range required {
if !have[s] {
existing = append(existing, s)
added = true
}
}
if added {
obj[key] = existing
}
return added
}
func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return "clawdbot", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
"Install Node.js first:\n" +
" https://nodejs.org/\n\n" +
"Then rerun:\n" +
" ollama launch\n" +
"and select OpenClaw")
}
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("openclaw installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install openclaw: %w", err)
}
if _, err := exec.LookPath("openclaw"); err != nil {
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
return "openclaw", nil
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
@@ -161,8 +505,7 @@ func (c *Openclaw) Edit(models []string) error {
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
ollama["api"] = "ollama"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
@@ -175,25 +518,13 @@ func (c *Openclaw) Edit(models []string) error {
}
}
client, _ := api.ClientFromEnvironment()
var newModels []any
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
for _, m := range models {
entry, _ := openclawModelConfig(context.Background(), client, m)
// Merge existing fields (user customizations)
if existing, ok := existingByID[model]; ok {
if existing, ok := existingByID[m]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
@@ -230,7 +561,213 @@ func (c *Openclaw) Edit(models []string) error {
if err != nil {
return err
}
return writeWithBackup(configPath, data)
if err := writeWithBackup(configPath, data); err != nil {
return err
}
// Clear any per-session model overrides so the new primary takes effect
// immediately rather than being shadowed by a cached modelOverride.
clearSessionModelOverride(models[0])
return nil
}
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
home, err := os.UserHomeDir()
if err != nil {
return
}
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var sessions map[string]map[string]any
if json.Unmarshal(data, &sessions) != nil {
return
}
changed := false
for _, sess := range sessions {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
sess["model"] = primary
changed = true
}
}
if !changed {
return
}
out, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
const webSearchNpmPackage = "@ollama/openclaw-web-search"
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present. Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
return true // already installed
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
return true
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir()
if err != nil {
return
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
data, err := os.ReadFile(configPath)
if err != nil {
return
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
return
}
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
plugins = make(map[string]any)
}
entries, _ := plugins["entries"].(map[string]any)
if entries == nil {
entries = make(map[string]any)
}
if _, ok := entries["openclaw-web-search"]; ok {
return // already registered
}
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
config["plugins"] = plugins
// Disable the built-in web search since our plugin replaces it.
tools, _ := config["tools"].(map[string]any)
if tools == nil {
tools = make(map[string]any)
}
web, _ := tools["web"].(map[string]any)
if web == nil {
web = make(map[string]any)
}
web["search"] = map[string]any{"enabled": false}
tools["web"] = web
config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ")
if err != nil {
return
}
_ = os.WriteFile(configPath, out, 0o600)
}
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
entry := map[string]any{
"id": modelID,
"name": modelID,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
}
if client == nil {
return entry, false
}
showCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
defer cancel()
}
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
if err != nil {
return entry, false
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
entry["input"] = []any{"text", "image"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
entry["reasoning"] = true
}
// Cloud models: use hardcoded limits for context/output tokens.
// Capability detection above still applies (vision, thinking).
if resp.RemoteModel != "" {
if l, ok := lookupCloudModelLimit(modelID); ok {
entry["contextWindow"] = l.Context
entry["maxTokens"] = l.Output
}
return entry, true
}
// Extract context window from ModelInfo (local models only)
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
entry["contextWindow"] = int(ctxLen)
}
break
}
}
return entry, false
}
func (c *Openclaw) Models() []string {

View File

@@ -1,11 +1,21 @@
package config
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestOpenclawIntegration(t *testing.T) {
@@ -26,6 +36,124 @@ func TestOpenclawIntegration(t *testing.T) {
})
}
func TestOpenclawRunPassthroughArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
bin := filepath.Join(tmpDir, "openclaw")
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 1 {
t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
}
if lines[0] != "gateway --someflag" {
t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
}
}
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
t.Run("success persists onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil {
t.Fatalf("loadIntegration() error = %v", err)
}
if !integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to be persisted after successful run")
}
})
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
t.Fatal("expected run failure")
}
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to remain unset after failed run")
}
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
@@ -359,19 +487,16 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
// Verify base schema fields (always set regardless of API availability)
if entry["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", entry["id"])
}
if entry["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", entry["name"])
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
@@ -876,3 +1001,589 @@ func TestOpenclawOnboarded(t *testing.T) {
}
})
}
func TestOpenclawGatewayInfo(t *testing.T) {
c := &Openclaw{}
t.Run("returns defaults when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("reads token and port from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {
"port": 9999,
"auth": {"mode": "token", "token": "my-secret"}
}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "my-secret" {
t.Errorf("expected token %q, got %q", "my-secret", token)
}
if port != 9999 {
t.Errorf("expected port 9999, got %d", port)
}
})
t.Run("uses default port when not in config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {"auth": {"token": "tok"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "tok" {
t.Errorf("expected token %q, got %q", "tok", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "legacy-token" {
t.Errorf("expected token %q, got %q", "legacy-token", token)
}
if port != 12345 {
t.Errorf("expected port 12345, got %d", port)
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
t.Run("handles missing gateway section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
}
func TestPrintOpenclawReady(t *testing.T) {
t.Run("includes port in URL", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", 9999, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "localhost:9999") {
t.Errorf("expected port 9999 in output, got:\n%s", output)
}
if strings.Contains(output, "#token=") {
t.Error("should not include token fragment when token is empty")
}
})
t.Run("URL-escapes token", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
escaped := url.QueryEscape("my token&special=chars")
if !strings.Contains(output, "#token="+escaped) {
t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
}
})
t.Run("simple token is not mangled", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "#token=ollama") {
t.Errorf("expected #token=ollama in output, got:\n%s", output)
}
})
t.Run("includes web UI hint", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Open the Web UI") {
t.Errorf("expected web UI hint in output, got:\n%s", output)
}
})
t.Run("first launch shows quick start tips", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
for _, want := range []string{"/help", "channels", "skills", "gateway"} {
if !strings.Contains(output, want) {
t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
}
}
})
t.Run("subsequent launch shows single tip", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Tip:") {
t.Errorf("expected single tip line, got:\n%s", output)
}
if strings.Contains(output, "Quick start") {
t.Errorf("should not show quick start on subsequent launch")
}
})
}
func TestOpenclawModelConfig(t *testing.T) {
t.Run("nil client returns base config", func(t *testing.T) {
cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")
if cfg["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", cfg["id"])
}
if cfg["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", cfg["name"])
}
if cfg["cost"] == nil {
t.Error("cost should be set")
}
// Should not have capability fields without API
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set without API")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set without API")
}
})
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
})
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 1 {
t.Errorf("input = %v, want [text]", cfg["input"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set for non-thinking model")
}
})
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwq")
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for thinking model")
}
})
t.Run("extracts context window from model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
if cfg["contextWindow"] != 131072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("handles all capabilities together", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true")
}
if cfg["contextWindow"] != 32768 {
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
}
})
t.Run("returns base config when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")
if cfg["id"] != "missing-model" {
t.Errorf("id = %v, want missing-model", cfg["id"])
}
// Should still have input (default)
if cfg["input"] == nil {
t.Error("input should always be set")
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set when show fails")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set when show fails")
}
})
t.Run("times out slow show and returns base config", func(t *testing.T) {
oldTimeout := openclawModelShowTimeout
openclawModelShowTimeout = 50 * time.Millisecond
t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
time.Sleep(300 * time.Millisecond)
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
start := time.Now()
cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
elapsed := time.Since(start)
if elapsed >= 250*time.Millisecond {
t.Fatalf("openclawModelConfig took too long: %v", elapsed)
}
if cfg["id"] != "slow-model" {
t.Errorf("id = %v, want slow-model", cfg["id"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set on timeout")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set on timeout")
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "test-model")
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set for zero value")
}
})
t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
// Use a model name that's in cloudModelLimits and make the server
// report it as a remote/cloud model
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud model")
}
if cfg["contextWindow"] != 204_800 {
t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
}
if cfg["maxTokens"] != 128_000 {
t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
}
})
t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports vision capability should have input: [text, image].
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud vision model")
}
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
}
})
t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports thinking capability should have reasoning: true.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud thinking model")
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for cloud thinking model")
}
})
}
func TestIntegrationOnboarded(t *testing.T) {
t.Run("returns false when not set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Error("expected false for fresh config")
}
})
t.Run("returns true after integrationOnboarded", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
})
t.Run("is case insensitive", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("OpenClaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true when set with different case")
}
})
t.Run("preserves existing integration data", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
// Verify onboarded is set
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
// Verify models are preserved
model := IntegrationModel("openclaw")
if model != "llama3.2" {
t.Errorf("expected first model llama3.2, got %q", model)
}
})
}

View File

@@ -10,10 +10,11 @@ import (
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
)
// ErrCancelled is returned when the user cancels a selection.

View File

@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
return s
}
func SelectSingle(title string, items []SelectItem) (string, error) {
// cursorForCurrent returns the item index matching current, or 0 if not found.
func cursorForCurrent(items []SelectItem, current string) int {
if current != "" {
for i, item := range items {
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
return i
}
}
}
return 0
}
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
if len(items) == 0 {
return "", fmt.Errorf("no items to select from")
}
m := selectorModel{
title: title,
items: items,
title: title,
items: items,
cursor: cursorForCurrent(items, current),
}
p := tea.NewProgram(m)
@@ -402,6 +415,12 @@ type multiSelectorModel struct {
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -416,13 +435,23 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
m.itemIndex[item.Name] = i
}
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
@@ -533,14 +562,25 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if len(m.checkOrder) > 0 {
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
m.toggleItem()
if m.multi {
m.toggleItem()
}
case tea.KeyUp:
if m.cursor > 0 {
@@ -579,7 +619,9 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
m.toggleItem()
if m.multi {
m.toggleItem()
}
} else {
m.filter += string(msg.Runes)
m.cursor = 0
@@ -591,6 +633,19 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
@@ -602,7 +657,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
@@ -624,6 +679,11 @@ func (m multiSelectorModel) View() string {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
@@ -648,7 +708,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(filtered) {
break
}
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -671,7 +731,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
m.renderMultiItem(&s, filtered[idx], idx)
renderItem(&s, filtered[idx], idx)
}
}
@@ -691,7 +751,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(otherItems) {
break
}
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -703,15 +763,18 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n")
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
@@ -734,18 +797,28 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled {
if fm.cancelled || !fm.confirmed {
return nil, ErrCancelled
}
if !fm.confirmed {
return nil, ErrCancelled
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
}
var result []string
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
for _, idx := range fm.checkOrder {
result = append(result, fm.items[idx].Name)
if idx != last {
result = append(result, fm.items[idx].Name)
}
}
return result, nil
}

View File

@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
}
}
// --- cursorForCurrent ---
func TestCursorForCurrent(t *testing.T) {
testItems := []SelectItem{
{Name: "llama3.2", Recommended: true},
{Name: "qwen3:8b", Recommended: true},
{Name: "gemma3:latest"},
{Name: "deepseek-r1"},
{Name: "glm-5:cloud"},
}
tests := []struct {
name string
current string
want int
}{
{"empty current", "", 0},
{"exact match", "qwen3:8b", 1},
{"no match returns 0", "nonexistent", 0},
{"bare name matches with :latest suffix", "gemma3", 2},
{"full tag matches bare item", "llama3.2:latest", 0},
{"cloud model exact match", "glm-5:cloud", 4},
{"cloud model bare name", "glm-5", 4},
{"recommended item exact match", "llama3.2", 0},
{"recommended item with tag", "qwen3", 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
}
})
}
}
// --- ReorderItems ---
func TestReorderItems(t *testing.T) {
@@ -503,6 +539,7 @@ func TestMultiView_CursorIndicator(t *testing.T) {
func TestMultiView_CheckedItemShowsX(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m.multi = true
content := m.View()
if !strings.Contains(content, "[x]") {
@@ -514,11 +551,18 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}
func TestMultiView_DefaultTag(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
m.multi = true
content := m.View()
if !strings.Contains(content, "(default)") {
t.Error("first checked item should have (default) tag")
t.Error("should have (default) tag")
}
// preChecked[0] ("a") should be the default (last in checkOrder)
aIdx := strings.Index(content, "a")
defaultIdx := strings.Index(content, "(default)")
if defaultIdx < aIdx {
t.Error("(default) tag should appear after 'a' (the current default)")
}
}
@@ -549,6 +593,7 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeySpace
@@ -565,6 +610,7 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
@@ -582,6 +628,161 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
}
}
// --- Single-add mode ---
func TestMulti_StartsInSingleMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Error("should start in single mode (multi=false)")
}
}
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
content := m.View()
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
t.Error("single mode should not show checkboxes")
}
if !strings.Contains(content, "▸") {
t.Error("single mode should show cursor indicator")
}
}
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.cursor = 1
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
m = updated.(multiSelectorModel)
if m.singleAdd != "b" {
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
}
if !m.confirmed {
t.Error("should set confirmed")
}
}
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space in single mode should not toggle items")
}
}
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space rune in single mode should not toggle items")
}
if m.filter != "" {
t.Error("space rune in single mode should not add to filter")
}
}
func TestMulti_TabTogglesMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Fatal("should start in single mode")
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if !m.multi {
t.Error("tab should switch to multi mode")
}
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if m.multi {
t.Error("tab should switch back to single mode")
}
}
func TestMulti_SingleModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
content := m.View()
if !strings.Contains(content, "tab add multiple") {
t.Error("single mode should show 'tab add multiple' in help")
}
}
func TestMulti_MultiModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
m.multi = true
content := m.View()
if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help")
}
}
// --- preChecked initialization order ---
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
// preChecked[0] ("a") is the current default and should end up
// last in checkOrder so it gets the (default) tag.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
if len(m.checkOrder) != 3 {
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
}
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
}
}
func TestMulti_CursorOnDefaultModel(t *testing.T) {
// preChecked[0] ("b") is the default; cursor should start on it
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
if m.cursor != 1 {
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
}
}
// --- Multi-mode last-checked is default ---
func TestMulti_LastCheckedIsDefault(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
m.multi = true
// Check "alpha" then "gamma"
m.cursor = 0
m.toggleItem()
m.cursor = 2
m.toggleItem()
// Last checked ("gamma") should be at the end of checkOrder
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "gamma" {
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
}
// The (default) tag renders based on checkOrder[len-1]
content := m.View()
if !strings.Contains(content, "(default)") {
t.Fatal("should show (default) tag")
}
// "alpha" line should NOT have the default tag
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
t.Error("'alpha' (first checked) should not have (default) tag")
}
}
}
// Key message helpers for testing
type keyType = int

View File

@@ -429,8 +429,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.multiModalSelector.confirmed {
var selected []string
for _, idx := range m.multiModalSelector.checkOrder {
selected = append(selected, m.multiModalSelector.items[idx].Name)
if m.multiModalSelector.singleAdd != "" {
// Single-add mode: prepend picked model, keep existing deduped
selected = []string{m.multiModalSelector.singleAdd}
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
if name != m.multiModalSelector.singleAdd {
selected = append(selected, name)
}
}
} else {
// Last checked is default (first in result)
co := m.multiModalSelector.checkOrder
last := co[len(co)-1]
selected = []string{m.multiModalSelector.items[last].Name}
for _, idx := range co {
if idx != last {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
}
}
if len(selected) > 0 {
m.changeModels = selected
@@ -508,7 +524,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter", " ":
item := m.items[m.cursor]
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
return m, nil
}
@@ -539,6 +555,12 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
if config.AutoInstallable(item.integration) {
// Auto-installable: select to trigger install flow
m.selected = true
m.quitting = true
return m, tea.Quit
}
return m, nil
}
if item.integration != "" && config.IsEditorIntegration(item.integration) {
@@ -602,7 +624,11 @@ func (m model) View() string {
var modelSuffix string
if item.integration != "" {
if !isInstalled {
title += " " + notInstalledStyle.Render("(not installed)")
if config.AutoInstallable(item.integration) {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
@@ -618,7 +644,9 @@ func (m model) View() string {
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
if config.AutoInstallable(item.integration) {
desc = "Press enter to install"
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"

View File

@@ -106,20 +106,23 @@
"group": "Integrations",
"pages": [
"/integrations/index",
{
"group": "Assistants",
"expanded": true,
"pages": [
"/integrations/openclaw"
]
},
{
"group": "Coding",
"expanded": true,
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose"
]
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
"/integrations/goose",
"/integrations/pi"
]
},
{

View File

@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)
- [Pi](/integrations/pi)
## Assistants

View File

@@ -4,47 +4,65 @@ title: OpenClaw
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
## Quick start
```bash
ollama launch openclaw
```
Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
## Configure without launching
To change the model without starting the gateway and TUI:
To configure without launching:
```shell
```bash
ollama launch openclaw --config
```
## Recommended Models
To use a specific model directly:
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
```bash
ollama launch openclaw --model kimi-k2.5:cloud
```
If the gateway is already running, it restarts automatically to pick up the new model.
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
- `glm-5:cloud` — Reasoning and code generation
**Local models:**
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
```bash
openclaw configure --section channels
```
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
## Stopping the gateway
```bash
openclaw gateway stop
```
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

57
docs/integrations/pi.mdx Normal file
View File

@@ -0,0 +1,57 @@
---
title: Pi
---
Pi is a minimal AI agent toolkit with plugin support.
## Install
Install [Pi](https://github.com/badlogic/pi-mono):
```bash
npm install -g @mariozechner/pi-coding-agent
```
## Usage with Ollama
### Quick setup
```bash
ollama launch pi
```
To configure without launching:
```shell
ollama launch pi --config
```
### Manual setup
Add a configuration block to `~/.pi/agent/models.json`:
```json
{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{
"id": "qwen3-coder"
}
]
}
}
}
```
Update `~/.pi/agent/settings.json` to set the default provider:
```json
{
"defaultProvider": "ollama",
"defaultModel": "qwen3-coder"
}
```

View File

@@ -27,9 +27,17 @@ The menu provides quick access to:
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
- **Additional integrations** - Available under "More..."
## Assistants
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
```sh
ollama launch openclaw
```
## Coding
Launch coding tools with Ollama models:
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
```sh
ollama launch claude

1
go.mod
View File

@@ -26,6 +26,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/klauspost/compress v1.18.3
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -122,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -150,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
return
}
defer reader.Close()
c.Request.Body = io.NopCloser(reader)
c.Request.Header.Del("Content-Encoding")
}
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
})
}
}
func zstdCompress(t *testing.T, data []byte) []byte {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(data); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return buf.Bytes()
}
func TestResponsesMiddlewareZstd(t *testing.T) {
tests := []struct {
name string
body string
useZstd bool
oversized bool
wantCode int
wantModel string
wantMessage string
}{
{
name: "plain JSON",
body: `{"model": "test-model", "input": "Hello"}`,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd compressed",
body: `{"model": "test-model", "input": "Hello"}`,
useZstd: true,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd over max decompressed size",
oversized: true,
useZstd: true,
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedRequest *api.ChatRequest
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
c.Status(http.StatusOK)
})
var bodyReader io.Reader
if tt.oversized {
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
} else if tt.useZstd {
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
} else {
bodyReader = strings.NewReader(tt.body)
}
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
req.Header.Set("Content-Type", "application/json")
if tt.useZstd || tt.oversized {
req.Header.Set("Content-Encoding", "zstd")
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.wantCode {
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
}
if tt.wantCode != http.StatusOK {
return
}
if capturedRequest == nil {
t.Fatal("expected captured request, got nil")
}
if capturedRequest.Model != tt.wantModel {
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
}
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
}
})
}
}

View File

@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
var p Parser
switch name {
case "qwen3":
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
case "qwen3-thinking":
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":

View File

@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
name string
}{
{"passthrough"},
{"qwen3"},
{"qwen3-thinking"},
{"qwen3-coder"},
{"harmony"},
}

335
model/parsers/qwen3.go Normal file
View File

@@ -0,0 +1,335 @@
package parsers
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwen3ParserState int
const (
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
qwen3ParserStateThinkingStartedEatingWhitespace
qwen3ParserStateCollectingThinking
qwen3ParserStateThinkingDoneEatingWhitespace
qwen3ParserStateCollectingContent
qwen3ParserStateToolStartedEatingWhitespace
qwen3ParserStateCollectingToolContent
)
const (
qwen3ThinkingOpenTag = "<think>"
qwen3ThinkingCloseTag = "</think>"
qwen3ToolOpenTag = "<tool_call>"
qwen3ToolCloseTag = "</tool_call>"
)
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
// with thinking content directly (without an opening tag).
type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
}
func (p *Qwen3Parser) HasToolSupport() bool {
return true
}
func (p *Qwen3Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
thinkingEnabled = p.defaultThinking
}
if p.hasThinkingSupport && thinkingEnabled {
p.state = qwen3ParserStateCollectingThinking
p.maybeThinkingOpenAtBOL = true
} else {
p.state = qwen3ParserStateCollectingContent
p.maybeThinkingOpenAtBOL = false
}
return tools
}
type qwen3Event interface {
isQwen3Event()
}
type qwen3EventContent struct {
content string
}
func (qwen3EventContent) isQwen3Event() {}
type qwen3EventRawToolCall struct {
raw string
}
func (qwen3EventRawToolCall) isQwen3Event() {}
type qwen3EventThinkingContent struct {
content string
}
func (qwen3EventThinkingContent) isQwen3Event() {}
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwen3EventRawToolCall:
toolCall, err := parseQwen3ToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)
case qwen3EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3Parser) parseEvents() []qwen3Event {
var all []qwen3Event
keepLooping := true
for keepLooping {
var events []qwen3Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true
}
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
return splitAtTag(&p.buffer, tag, trimAfter)
}
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
var events []qwen3Event
switch p.state {
case qwen3ParserStateLookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingThinking
}
return events, true
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
} else if trimmed == "" {
return events, false
}
p.state = qwen3ParserStateCollectingContent
return events, true
case qwen3ParserStateThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
case qwen3ParserStateCollectingThinking:
acc := p.buffer.String()
// Some qwen3 checkpoints emit an explicit opening <think> tag even
// though the prompt already ended with <think>. Strip exactly one
// leading opening tag if present.
if p.maybeThinkingOpenAtBOL {
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
return events, false
}
p.maybeThinkingOpenAtBOL = false
return events, true
}
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
}
p.maybeThinkingOpenAtBOL = false
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwen3EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
case qwen3ParserStateThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
case qwen3ParserStateCollectingContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolOpenTag) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
case qwen3ParserStateToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
case qwen3ParserStateCollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolCloseTag) {
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("qwen3 tool call closing tag found but no content before it")
}
events = append(events, qwen3EventRawToolCall{raw: toolContent})
p.state = qwen3ParserStateCollectingContent
return events, true
}
return events, false
default:
panic("unreachable")
}
}
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var parsed struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
}
if parsed.Name == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for key, value := range parsed.Arguments {
toolCall.Function.Arguments.Set(key, value)
}
return toolCall, nil
}

147
model/parsers/qwen3_test.go Normal file
View File

@@ -0,0 +1,147 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen3ParserThinkingEnabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<thi", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "" || len(calls) != 0 {
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
}
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingDisabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, nil)
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCall(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Fatalf("expected location %q, got %v", "San Francisco", location)
}
unit, ok := calls[0].Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}

View File

@@ -2,6 +2,10 @@
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
# Wrap script in main function so that a truncated partial download doesn't end
# up executing half a script.
main() {
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
@@ -446,3 +450,6 @@ fi
status "NVIDIA GPU ready."
install_success
}
main

View File

@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Fatal(err)
}
loadedModel, err := GetModel("test-image")
if err != nil {
t.Fatal(err)
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
schedulerModelKey(loadedModel): {
llama: &mock,
Options: &opts,
model: loadedModel,
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",

View File

@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
// schedulerModelKey returns the scheduler map key for a model.
// GGUF-backed models use ModelPath; safetensors/image models without a
// ModelPath use manifest digest so distinct models don't collide.
func schedulerModelKey(m *Model) string {
if m == nil {
return ""
}
if m.ModelPath != "" {
return m.ModelPath
}
if m.Digest != "" {
return "digest:" + m.Digest
}
if m.Name != "" {
return "name:" + m.Name
}
if m.ShortName != "" {
return "short:" + m.ShortName
}
return ""
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -514,6 +539,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -528,7 +554,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -536,7 +562,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
@@ -684,6 +711,7 @@ type runnerRef struct {
model *Model
modelPath string
modelKey string
numParallel int
*api.Options
}
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
// Secondary sort by model key/path lex order
n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath]
runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()

View File

@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
require.Empty(t, successCh)
require.Empty(t, errCh)
require.Len(t, s.pendingReqCh, 1)
}
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
cancelReq()
select {
case runner := <-successCh:
require.Equal(t, loadedRunner, runner)
default:
t.Fatal("expected existing runner to be reused")
}
require.Empty(t, errCh)
require.Empty(t, s.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()

View File

@@ -30,6 +30,8 @@ type ModelfileConfig struct {
Template string
System string
License string
Parser string
Renderer string
}
// CreateOptions holds all options for model creation.
@@ -37,7 +39,7 @@ type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
}
// CreateModel imports a model from a local directory.
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
ModelFormat: "safetensors",
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: parserName,
Renderer: rendererName,
Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: resolveRendererName(opts.Modelfile, rendererName),
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
}
}
func resolveParserName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Parser != "" {
return mf.Parser
}
return inferred
}
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Renderer != "" {
return mf.Renderer
}
return inferred
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
return "qwen3"
}
}

View File

@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
Parser: "qwen3",
Renderer: "qwen3",
}
if config.Template != "{{ .Prompt }}" {
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
if config.Parser != "qwen3" {
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
}
if config.Renderer != "qwen3" {
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
if config.Parser != "" {
t.Errorf("Parser should be empty, got %q", config.Parser)
}
if config.Renderer != "" {
t.Errorf("Renderer should be empty, got %q", config.Renderer)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
if config.License != "" {
t.Error("License should be empty")
}
if config.Parser != "" {
t.Error("Parser should be empty")
}
if config.Renderer != "" {
t.Error("Renderer should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
Template: "test",
System: "system",
License: "MIT",
Parser: "qwen3-thinking",
Renderer: "qwen3",
},
}
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
if opts.Modelfile.Parser != "qwen3-thinking" {
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
}
if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
}
}
func TestResolveParserName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3",
want: "qwen3",
},
{
name: "empty parser uses inferred",
mf: &ModelfileConfig{
Parser: "",
},
inferred: "qwen3",
want: "qwen3",
},
{
name: "explicit parser overrides inferred",
mf: &ModelfileConfig{
Parser: "qwen3-thinking",
},
inferred: "qwen3",
want: "qwen3-thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveRendererName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "empty renderer uses inferred",
mf: &ModelfileConfig{
Renderer: "",
},
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "explicit renderer overrides inferred",
mf: &ModelfileConfig{
Renderer: "qwen3",
},
inferred: "qwen3-coder",
want: "qwen3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
}
})
}
}
func TestCreateOptions_Defaults(t *testing.T) {

View File

@@ -3,5 +3,8 @@
package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
)

View File

@@ -18,7 +18,9 @@
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
CHECK(handle->ctx != NULL);
if (handle->ctx == NULL) {
return 1;
}
return 0;
}

View File

@@ -55,6 +55,30 @@ func tryLoadFromDir(dir string) bool {
return false
}
// tryLoadByName attempts to load the library using just its name,
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
if runtime.GOOS == "linux" {
libraryName = "libmlxc.so"
}
cPath := C.CString(libraryName)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
return false
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
C.mlx_dynamic_unload(&handle)
return false
}
return true
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -73,6 +97,11 @@ func init() {
}
}
// Try loading via rpath/standard library search
if tryLoadByName() {
return
}
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {

View File

@@ -8,10 +8,10 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.

View File

@@ -0,0 +1,92 @@
//go:build mlx
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
type LinearFactory struct {
tensors map[string]*mlx.Array
defaultGroupSize int
defaultBits int
defaultMode string
tensorQuant map[string]*TensorQuantInfo
}
// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
tensors map[string]*mlx.Array,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
return LinearFactory{
tensors: tensors,
defaultGroupSize: defaultGroupSize,
defaultBits: defaultBits,
defaultMode: defaultMode,
tensorQuant: tensorQuant,
}
}
// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
return MakeLinearLayer(
f.tensors,
path,
f.defaultGroupSize,
f.defaultBits,
f.defaultMode,
f.tensorQuant,
)
}
// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}

130
x/mlxrunner/model/quant.go Normal file
View File

@@ -0,0 +1,130 @@
//go:build mlx
package model
import (
"strings"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
// when available, otherwise falling back to the provided model defaults.
func TensorQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
) (groupSize, bits int, mode string, fromTensor bool) {
if tensorQuant != nil {
if tq := tensorQuant[tensorName]; tq != nil {
groupSize, bits, mode = QuantizationParams(tq.QuantType)
if tq.GroupSize > 0 {
groupSize = tq.GroupSize
}
return groupSize, bits, mode, true
}
}
return defaultGroupSize, defaultBits, defaultMode, false
}
// ResolveLinearQuantParams resolves quantization params for a quantized linear
// tensor, preferring per-tensor metadata and falling back to shape-based
// inference for affine packed tensors.
func ResolveLinearQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
weight, scales *mlx.Array,
) (groupSize, bits int, mode string) {
groupSize, bits, mode, fromTensor := TensorQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
tensorName,
)
if mode == "affine" {
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
if !fromTensor || groupSize == 0 || bits == 0 {
groupSize = inferredGroupSize
bits = inferredBits
}
}
}
return groupSize, bits, mode
}
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
// tensors from packed weight and scale shapes.
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
if weight == nil || scales == nil {
return 0, 0, false
}
weightShape := weight.Dims()
scaleShape := scales.Dims()
if len(weightShape) == 0 || len(scaleShape) == 0 {
return 0, 0, false
}
weightCols := weightShape[len(weightShape)-1]
scalesCols := scaleShape[len(scaleShape)-1]
if weightCols <= 0 || scalesCols <= 0 {
return 0, 0, false
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return 32, 4, true
case groupSize8 == 64:
return 64, 8, true
case groupSize4 == 64 && groupSize8 == 32:
if hintBits == 8 {
return 32, 8, true
}
if hintBits == 4 {
return 64, 4, true
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return groupSize4, 4, true
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return groupSize8, 8, true
}
return 0, 0, false
}
func isCommonGroupSize(v int) bool {
switch v {
case 16, 32, 64, 128:
return true
default:
return false
}
}

View File

@@ -8,42 +8,63 @@ import (
"fmt"
"io"
"os"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
quantType string
groupSize int
// TensorQuantInfo describes per-tensor quantization metadata.
type TensorQuantInfo struct {
QuantType string
GroupSize int
}
// Open loads a manifest for the given model name and pre-scans the first
// tensor blob for quantization metadata (quant_type, group_size).
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
// Backwards-compatible model-level quant metadata (first tensor blob).
quantType string
groupSize int
// Per-tensor quantization metadata.
tensorQuant map[string]*TensorQuantInfo
}
// Open loads a manifest for the given model name and scans tensor blobs for
// quantization metadata.
func Open(modelName string) (*Root, error) {
m, err := manifest.LoadManifest(modelName)
if err != nil {
return nil, err
}
root := &Root{Manifest: m}
root := &Root{
Manifest: m,
tensorQuant: make(map[string]*TensorQuantInfo),
}
// Pre-scan first tensor blob for quantization metadata
for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest)
meta, err := readBlobMetadata(blobPath)
if err != nil || meta == nil {
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
if err != nil {
continue
}
if qt := meta["quant_type"]; qt != "" {
root.quantType = strings.ToUpper(qt)
for name, info := range infos {
root.tensorQuant[name] = info
}
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &root.groupSize)
if root.quantType == "" && blobQuantType != "" {
root.quantType = strings.ToUpper(blobQuantType)
root.groupSize = blobGroupSize
if root.groupSize == 0 {
root.groupSize = defaultGroupSize(root.quantType)
}
}
break // only check the first tensor blob
}
return root, nil
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
// Close is a no-op for now (future: release resources).
func (r *Root) Close() {}
// QuantType returns the quantization type detected from tensor metadata.
// QuantType returns the quantization type detected from the first tensor blob metadata.
func (r *Root) QuantType() string { return r.quantType }
// GroupSize returns the quantization group size detected from tensor metadata.
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
func (r *Root) GroupSize() int { return r.groupSize }
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
// TensorQuant returns per-tensor quantization metadata if available.
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
if r == nil {
return nil
}
return r.tensorQuant[name]
}
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
for k, v := range r.tensorQuant {
if v == nil {
continue
}
copy := *v
out[k] = &copy
}
return out
}
func defaultGroupSize(quantType string) int {
groupSize, _, _ := QuantizationParams(quantType)
return groupSize
}
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
return nil, "", 0, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
return nil, "", 0, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
if headerSize > 100*1024*1024 {
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
return nil, "", 0, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
return nil, "", 0, err
}
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
if _, ok := header[name+".scale"]; !ok {
continue
}
quantType := globalQuantType
groupSize := globalGroupSize
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType
}
if groupSize == 0 {
groupSize = inferredGroup
}
if quantType == "" {
continue
}
if groupSize == 0 {
groupSize = defaultGroupSize(quantType)
}
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
}
return infos, globalQuantType, globalGroupSize, nil
}
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
metaRaw, ok := header["__metadata__"]
if !ok {
return nil, nil
return "", 0
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err
return "", 0
}
return meta, nil
quantType = meta["quant_type"]
if gs := meta["group_size"]; gs != "" {
groupSize, _ = strconv.Atoi(gs)
}
return quantType, groupSize
}
func mainTensorNames(header map[string]json.RawMessage) []string {
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
type tensorShape struct {
Shape []int64 `json:"shape"`
}
mainRaw, ok := header[tensorName]
if !ok {
return "", 0
}
scaleRaw, ok := header[tensorName+".scale"]
if !ok {
return "", 0
}
var mainInfo tensorShape
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
return "", 0
}
var scaleInfo tensorShape
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
return "", 0
}
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
if weightCols <= 0 || scalesCols <= 0 {
return "", 0
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return "INT4", 32
case groupSize8 == 64:
return "INT8", 64
case groupSize4 == 64 && groupSize8 == 32:
h := strings.ToUpper(hintQuantType)
if strings.Contains(h, "8") {
return "INT8", 32
}
if strings.Contains(h, "4") {
return "INT4", 64
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return "INT4", groupSize4
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return "INT8", groupSize8
}
return "", 0
}

View File

@@ -7,7 +7,6 @@ import (
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@@ -18,15 +17,27 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
mlx.EnableCompile()
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
} else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
@@ -114,13 +125,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
return flushValidUTF8Prefix(b)
}

View File

@@ -12,12 +12,12 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
type Request struct {

View File

@@ -0,0 +1,47 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}

View File

@@ -0,0 +1,46 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}

521
x/models/gemma3/gemma3.go Normal file
View File

@@ -0,0 +1,521 @@
//go:build mlx
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
package gemma3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Gemma3ForCausalLM", newModel)
base.Register("Gemma3ForConditionalGeneration", newModel)
}
// TextConfig holds configuration for the Gemma 3 text model.
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
}
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
QNormScaled *mlx.Array
KNormScaled *mlx.Array
}
// MLP is the feed-forward network with GELU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// DecoderLayer is a single transformer block.
type DecoderLayer struct {
InputNorm *nn.RMSNorm
Attention *Attention
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
MLP *MLP
PostFFNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
// Layer metadata.
IsSliding bool
LayerIdx int32
}
// Model is the Gemma 3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*DecoderLayer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
// Precomputed (1 + weight) for Gemma-style RMSNorm.
NormScaled *mlx.Array
tok *tokenizer.Tokenizer
*TextConfig
weightPrefix string
}
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
switch numLayers {
case 34:
return 8, 4
case 48:
return 16, 8
case 62:
return 32, 16
default:
return 8, 4
}
}
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
var cfg TextConfig
if err := json.Unmarshal(configData, &cfg); err != nil {
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
}
var wrapped struct {
TextConfig *TextConfig `json:"text_config"`
}
if err := json.Unmarshal(configData, &wrapped); err != nil {
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
}
fromConditional := wrapped.TextConfig != nil
if fromConditional {
cfg = *wrapped.TextConfig
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
cfg.SlidingWindowPattern = 6
}
if cfg.MaxPositionEmbeddings == 0 {
cfg.MaxPositionEmbeddings = 131072
}
}
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
return cfg, fromConditional, nil
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
return cfg.LayerTypes[layerIdx] == "sliding_attention"
}
if cfg.SlidingWindowPattern <= 0 {
return false
}
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
}
func precomputeGemmaScaledWeights(m *Model) {
if m.Norm != nil {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
}
var scaled []*mlx.Array
if m.NormScaled != nil {
scaled = append(scaled, m.NormScaled)
}
for _, layer := range m.Layers {
if layer == nil || layer.Attention == nil {
continue
}
if layer.InputNorm != nil {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
scaled = append(scaled, layer.InputNormScaled)
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
scaled = append(scaled, layer.PostAttnNormScaled)
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PreFFNormScaled)
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PostFFNormScaled)
}
if layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.QNormScaled)
}
if layer.Attention.KNorm != nil {
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.KNormScaled)
}
}
if len(scaled) > 0 {
mlx.Eval(scaled...)
}
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
cfg, _, err := parseTextConfig(configData)
if err != nil {
return nil, err
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), m.TextConfig),
}
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Gemma usually ties output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &DecoderLayer{
LayerIdx: i,
IsSliding: isLayerSliding(i, m.TextConfig),
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.InputNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.PostAttnNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.PreFFNorm == nil {
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
}
if layer.PostFFNorm == nil {
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
precomputeGemmaScaledWeights(m)
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// NewCaches creates cache objects for all layers.
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if m.SlidingWindow > 0 && layer.IsSliding {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
// FormatPrompt applies the Gemma 3 chat template.
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}

View File

@@ -8,14 +8,13 @@ import (
"encoding/json"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
@@ -64,9 +63,10 @@ type Config struct {
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// quantizationParams returns groupSize, bits, mode for a quantization type string.
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
// Check if quantized and dequantize
if scales := tensors[path+".weight_scale"]; scales != nil {
qbiases := tensors[path+".weight_qbias"]
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
return embedQ, unembedOut
}
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: cfg.QuantGroupSize,
Bits: cfg.QuantBits,
Mode: cfg.QuantMode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
// Set up quantization parameters from pre-scanned metadata
if qt := root.QuantType(); qt != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
} else {
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
// Load tokenizer
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
cfg := m.Config
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
if !useQuantized && cfg.TensorQuant != nil {
for _, tq := range cfg.TensorQuant {
if tq == nil {
continue
}
_, bits, mode := model.QuantizationParams(tq.QuantType)
if supportsGatherQMM(mode, bits) {
useQuantized = true
break
}
}
}
// Load embedding
if w := tensors["model.embed_tokens.weight"]; w != nil {
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
// Load LM head
m.LMHead = makeLinear(tensors, "lm_head", cfg)
m.LMHead = linears.Make("lm_head")
// Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load attention (same for both block types)
attn := &MLAAttention{}
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
block.MLP = &DenseMLP{
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.up_proj"),
DownProj: linears.Make(prefix + ".mlp.down_proj"),
}
m.Layers[i] = block
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
moeGate := &MoEGate{}
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias
}
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
}
}

323
x/models/llama/llama.go Normal file
View File

@@ -0,0 +1,323 @@
//go:build mlx
// Package llama provides a Llama-style decoder-only transformer for MLX.
package llama
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("LlamaForCausalLM", newModel)
}
// Config holds Llama model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
// Model is a Llama text model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
}
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
if cfg.HeadDim == 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-5
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fallback used by many Llama checkpoints where output is tied.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}

338
x/models/qwen3/qwen3.go Normal file
View File

@@ -0,0 +1,338 @@
//go:build mlx
// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Qwen3ForCausalLM", newModel)
}
// Config holds Qwen3 model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
QKNormEps float32 `json:"-"`
}
// Model is the Qwen3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
// Layer is a single Qwen3 decoder block.
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
}
// MLP is the feed-forward network with SwiGLU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HeadDim == 0 {
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
cfg.QKNormEps = 1e-6
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Qwen3 checkpoints commonly tie output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
q = a.QNorm.Forward(q, cfg.QKNormEps)
k = a.KNorm.Forward(k, cfg.QKNormEps)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"os"
"sort"
"strings"
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
}
}
return buildModelInfo(config, totalBytes, tensorCount), nil
info := buildModelInfo(config, totalBytes, tensorCount)
// For quantized models, byte-based estimation can significantly undercount
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info, nil
}
// buildModelInfo constructs the model info map from config and tensor stats.
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
return info
}
// getParameterCountFromManifest counts model parameters from tensor shapes.
// This accounts for quantized tensors by using unpacked shapes from
// getTensorInfoFromManifest.
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
tensors, err := getTensorInfoFromManifest(mf)
if err != nil {
return 0, err
}
var total int64
for _, tensor := range tensors {
if len(tensor.Shape) == 0 {
continue
}
elements := int64(1)
for _, dim := range tensor.Shape {
if dim == 0 {
elements = 0
break
}
if dim > uint64(math.MaxInt64) {
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
}
d := int64(dim)
if elements > math.MaxInt64/d {
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
}
elements *= d
}
if elements == 0 {
continue
}
if total > math.MaxInt64-elements {
return 0, fmt.Errorf("total parameter count overflow")
}
total += elements
}
return total, nil
}
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {

View File

@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
}
}
func TestGetParameterCountFromManifest(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Unquantized tensor: [4,5] = 20 params
header1 := map[string]any{
"model.embed_tokens.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{4, 5},
"data_offsets": []int64{0, 40},
},
}
header1JSON, _ := json.Marshal(header1)
var buf1 bytes.Buffer
binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
buf1.Write(header1JSON)
digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
blobPath1, err := manifest.BlobsPath(digest1)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob1: %v", err)
}
// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
header2 := map[string]any{
"__metadata__": map[string]string{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10, 2},
"data_offsets": []int64{0, 80},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{80, 100},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{100, 120},
},
}
header2JSON, _ := json.Marshal(header2)
var buf2 bytes.Buffer
binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
buf2.Write(header2JSON)
digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
blobPath2, err := manifest.BlobsPath(digest2)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob2: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest1,
Size: int64(buf1.Len() + 40),
Name: "model.embed_tokens.weight",
},
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest2,
Size: int64(buf2.Len() + 120),
Name: "model.layers.0.mlp.up_proj.weight",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 180 // 20 + 160
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Packed mixed-precision blob (no global metadata):
// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
header := map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 8},
"data_offsets": []int64{0, 160},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{160, 180},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{180, 200},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 16},
"data_offsets": []int64{200, 520},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{520, 530},
},
"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{530, 540},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest,
Size: int64(buf.Len() + 540),
Name: "model.layers.0.mlp.experts",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 640 // 320 + 320
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
func TestParseSafetensorsAllHeaders(t *testing.T) {
tests := []struct {
name string

108
x/tokenizer/tokenizer.go Normal file
View File

@@ -0,0 +1,108 @@
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import "regexp"
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE and SentencePiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
sortedSpecialTokens []string // Special tokens sorted by length, longest first
typ TokenizerType // Algorithm type
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}

View File

@@ -0,0 +1,251 @@
//go:build mlx
package tokenizer
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var (
benchmarkSinkIDs []int32
benchmarkSinkStr string
benchmarkSinkTok *Tokenizer
)
const benchmarkWordPieceJSON = `{
"model": {
"type": "WordPiece",
"vocab": {
"[UNK]": 0,
"hello": 1,
"##world": 2,
"##ly": 3,
"##hello": 4
}
},
"added_tokens": []
}`
const benchmarkSentencePieceJSON = `{
"model": {
"type": "BPE",
"vocab": {
"\u2581": 0,
"h": 1,
"e": 2,
"l": 3,
"o": 4,
"w": 5,
"r": 6,
"d": 7,
"<0x0A>": 8
},
"merges": []
},
"decoder": {
"type": "Sequence",
"decoders": [
{
"type": "Replace",
"pattern": {
"String": "\u2581"
}
}
]
},
"added_tokens": []
}`
func benchmarkMiniLlamaPath(tb testing.TB) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve benchmark file path")
}
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
tb.Helper()
data := benchmarkLoadMiniLlamaBytes(tb)
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
}
return tok
}
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
tb.Helper()
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
if err != nil {
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
}
return data
}
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
tb.Helper()
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
}
return tok
}
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "short", text: "Hello, world!"},
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(input.text, false)
}
})
}
}
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
}
for _, input := range inputs {
ids := tok.Encode(input.text, false)
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
})
}
}
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
data := benchmarkLoadMiniLlamaBytes(b)
config := &TokenizerConfig{
TokenizerConfigJSON: []byte(`{
"bos_token": {"content": "<|begin_of_text|>"},
"eos_token": {"content": "<|end_of_text|>"},
"add_bos_token": true
}`),
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
}
b.Run("without_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytes(data)
if err != nil {
b.Fatalf("LoadFromBytes failed: %v", err)
}
benchmarkSinkTok = tok
}
})
b.Run("with_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytesWithConfig(data, config)
if err != nil {
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
}
benchmarkSinkTok = tok
}
})
}
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}

View File

@@ -0,0 +1,175 @@
//go:build mlx
package tokenizer
import "container/heap"
type bpeMergeNode struct {
prev int
next int
token string
}
type bpePair struct {
left int
right int
rank int
value string
}
type bpePairHeap []*bpePair
func (h bpePairHeap) Len() int { return len(h) }
func (h bpePairHeap) Less(i, j int) bool {
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
}
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bpePairHeap) Push(x any) {
*h = append(*h, x.(*bpePair))
}
func (h *bpePairHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
runes := []rune(encoded)
if len(runes) == 0 {
return ids
}
nodes := make([]bpeMergeNode, len(runes))
for i := range runes {
nodes[i] = bpeMergeNode{
prev: i - 1,
next: i + 1,
token: string(runes[i]),
}
}
pairwise := func(left, right int) *bpePair {
if left < 0 || right >= len(nodes) {
return nil
}
if nodes[left].token == "" || nodes[right].token == "" {
return nil
}
leftToken, rightToken := nodes[left].token, nodes[right].token
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
if !ok {
return nil
}
value := leftToken + rightToken
if _, ok := t.vocab.Reverse[value]; !ok {
return nil
}
return &bpePair{
left: left,
right: right,
rank: rank,
value: value,
}
}
pairs := bpePairHeap{}
heap.Init(&pairs)
for i := 0; i < len(runes)-1; i++ {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(&pairs, pair)
}
}
for pairs.Len() > 0 {
pair := heap.Pop(&pairs).(*bpePair)
left, right := nodes[pair.left], nodes[pair.right]
if left.token == "" || right.token == "" {
continue
}
if left.next != pair.right || right.prev != pair.left {
continue
}
if left.token+right.token != pair.value {
continue
}
nodes[pair.left].token = pair.value
nodes[pair.right].token = ""
nodes[pair.left].next = right.next
if right.next < len(nodes) {
nodes[right.next].prev = pair.left
}
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
heap.Push(&pairs, pair)
}
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
heap.Push(&pairs, pair)
}
}
for _, node := range nodes {
if node.token == "" {
continue
}
if id, ok := t.vocab.Reverse[node.token]; ok {
ids = append(ids, id)
continue
}
ids = t.appendByteFallback(ids, node.token)
}
return ids
}
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
if t.typ == TokenizerBPE {
for _, r := range token {
if b, ok := decodeByteLevelRune(r); ok {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
return ids
}
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
for _, b := range []byte(token) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
return ids
}
func decodeByteLevelRune(r rune) (byte, bool) {
switch {
case r >= 0x00 && r <= 0xFF:
return byte(r), true
case r == 0x0100:
return 0x00, true
case r == 0x0143:
return 0x00ad, true
case r > 0x0100 && r <= 0x0120:
return byte(r - 0x0100), true
case r > 0x0120 && r <= 0x0142:
return byte(r - 0x00a2), true
default:
return 0, false
}
}

View File

@@ -0,0 +1,137 @@
//go:build mlx
package tokenizer
import (
"runtime"
"strings"
"testing"
)
func equalIDs(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestEncodeRoundtripMiniLlama(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
inputs := []string{
"",
"hello",
"hello world",
" hello world ",
"don't we'll they're",
"1234567890",
"こんにちは世界",
"Hello 世界",
"func main() {}",
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
}
for _, input := range inputs {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
}
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
// Simulate construction outside loader path where cache is not set.
tok.sortedSpecialTokens = nil
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
prev := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(prev)
runtime.GOMAXPROCS(1)
seq := tok.Encode(input, false)
if prev < 2 {
runtime.GOMAXPROCS(2)
} else {
runtime.GOMAXPROCS(prev)
}
par := tok.Encode(input, false)
if !equalIDs(seq, par) {
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
}
}

View File

@@ -0,0 +1,56 @@
//go:build mlx
package tokenizer
import (
"strconv"
"strings"
)
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// Mirror GGML tokenizer behavior for NULL byte.
// 0x00 is omitted during decode.
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,289 @@
//go:build mlx
package tokenizer
import (
"runtime"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
const (
encodeParallelMinInputBytes = 4 * 1024
encodeParallelMinChunksPerWorker = 8
)
type tokenMatch struct {
start int
end int
}
type encodeChunk struct {
text string
isSpecial bool
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
tokens := t.sortedSpecialTokens
if len(tokens) == 0 {
// Fallback for tokenizers constructed outside the loaders.
tokens = make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
}
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
m := part[curr.start:curr.end]
nextText := part[next.start:next.end]
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
return
}
firstRune, _ := utf8.DecodeRuneInString(nextText)
if !unicode.IsLetter(firstRune) {
return
}
lastSpaceStart := curr.end
for j := curr.end; j > curr.start; {
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > curr.start {
curr.end = lastSpaceStart
next.start = lastSpaceStart
} else {
next.start = curr.start
curr.end = curr.start
}
}
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
if _, ok := t.specialTokens[part]; ok {
fn(encodeChunk{text: part, isSpecial: true})
return
}
if t.pretokenizer == nil {
fn(encodeChunk{text: part, isSpecial: false})
return
}
re := t.pretokenizer
offset := 0
loc := re.FindStringIndex(part[offset:])
if loc == nil {
return
}
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
for {
loc = re.FindStringIndex(part[offset:])
if loc == nil {
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
return
}
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
adjustWhitespaceBoundary(part, &curr, &next)
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
curr = next
}
}
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
return append(ids, id)
}
return ids
}
return t.encodeChunkInto(c.text, ids)
}
// Encode tokenizes text to token IDs.
// Parallel encoding is used only for very large inputs with enough chunks per worker.
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Fast path: encode sequentially without materializing chunk slices.
if len(s) < encodeParallelMinInputBytes {
var ids []int32
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
ids = t.appendEncodedChunk(ids, c)
})
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// For large inputs collect chunks to enable parallel processing.
var allChunks []encodeChunk
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
allChunks = append(allChunks, c)
})
}
// Encode chunks. Use the parallel path only when the chunk count is
// large enough to amortize goroutine/synchronization overhead.
useParallel := true
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
useParallel = false
}
var ids []int32
if !useParallel {
for _, c := range allChunks {
ids = t.appendEncodedChunk(ids, c)
}
} else {
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []encodeChunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
r = t.appendEncodedChunk(r, c)
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}

View File

@@ -0,0 +1,207 @@
//go:build mlx
package tokenizer
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func llama32GGMLFixturePath(tb testing.TB, file string) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve test file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
tb.Helper()
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
if err != nil {
tb.Fatalf("failed to open encoder.json: %v", err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
tb.Fatalf("failed to decode encoder.json: %v", err)
}
type addedToken struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
}
var addedTokens []addedToken
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
id := int32(len(vocab))
vocab[token] = id
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
}
}
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
if err != nil {
tb.Fatalf("failed to open vocab.bpe: %v", err)
}
defer mf.Close()
var merges []string
scanner := bufio.NewScanner(mf)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
line = strings.TrimSpace(line)
if line != "" {
merges = append(merges, line)
}
}
if err := scanner.Err(); err != nil {
tb.Fatalf("failed to read vocab.bpe: %v", err)
}
payload := struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
AddedTokens []addedToken `json:"added_tokens"`
}{}
payload.Model.Type = "BPE"
payload.Model.Vocab = vocab
payload.Model.Merges = merges
payload.PreTokenizer.Type = "Sequence"
payload.PreTokenizer.Pretokenizers = []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}{
{
Type: "Split",
Pattern: struct {
Regex string `json:"Regex"`
}{
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
},
},
}
payload.AddedTokens = addedTokens
data, err := json.Marshal(payload)
if err != nil {
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
}
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
}
return tok
}
func TestGGMLLlamaKnownEncodings(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[string][]int32{
"hello world": {15339, 1917},
"hello <|end_of_text|>": {15339, 220, 128001},
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for input, want := range cases {
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[int][]int32{
1: {15},
2: {410},
3: {931},
4: {931, 15},
5: {931, 410},
6: {931, 931},
7: {931, 931, 15},
8: {931, 931, 410},
9: {931, 931, 931},
10: {931, 931, 931, 15},
11: {931, 931, 931, 410},
12: {931, 931, 931, 931},
13: {931, 931, 931, 931, 15},
14: {931, 931, 931, 931, 410},
15: {931, 931, 931, 931, 931},
16: {931, 931, 931, 931, 931, 15},
17: {931, 931, 931, 931, 931, 410},
}
for n, want := range cases {
input := strings.Repeat("0", n)
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, input := range cases {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
ids := tok.Encode(string(rune(0x00)), false)
got := tok.Decode(ids)
if got != "" {
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
}
}

View File

@@ -0,0 +1,458 @@
//go:build mlx
package tokenizer
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data)
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data)
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Covers SentencePiece and BPE models
if raw.Model.Type != "BPE" {
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
var mergesStrings []string
if raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
if len(pair) != 2 {
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
}
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
cacheSortedSpecialTokens(t)
return t, nil
}
func cacheSortedSpecialTokens(t *Tokenizer) {
if len(t.specialTokens) == 0 {
t.sortedSpecialTokens = nil
return
}
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
t.sortedSpecialTokens = tokens
}
type specialTokenConfigData struct {
tokenizerConfigJSON []byte
generationConfigJSON []byte
specialTokensMapJSON []byte
configJSON []byte
}
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.generationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.tokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.specialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
applySpecialTokenConfig(t, specialTokenConfigData{
tokenizerConfigJSON: config.TokenizerConfigJSON,
generationConfigJSON: config.GenerationConfigJSON,
specialTokensMapJSON: config.SpecialTokensMapJSON,
configJSON: config.ConfigJSON,
})
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}

View File

@@ -0,0 +1,26 @@
//go:build mlx
package tokenizer
import (
"strings"
"testing"
)
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
data := []byte(`{
"model": {
"type": "WordPiece",
"vocab": {"[UNK]": 0, "hello": 1}
},
"added_tokens": []
}`)
_, err := LoadFromBytes(data)
if err == nil {
t.Fatal("expected WordPiece load to fail")
}
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
t.Fatalf("unexpected error: %v", err)
}
}