Compare commits

..

2 Commits

Author SHA1 Message Date
Bruce MacDonald
365a3657ad fix test home on windows 2026-02-18 18:37:24 -08:00
Bruce MacDonald
71c1d8d0a9 cmd: ollama update
Add interactive update check to CLI TUI and `ollama update` command

On TUI launch, check for updates in the background and cache the result as a marker file (~/.ollama/update). On the next launch, if a cached update exists, print a one-line notice before the TUI starts. The check is skipped for dev builds (0.0.0), alternative installs (e.g. brew, choco), and remote Ollama hosts.

Add `ollama update` subcommand that downloads and runs the platform-appropriate install script (install.sh on Linux/macOS, install.ps1 on Windows). Refuses to run if the binary wasn't installed via official channels unless --force is passed.

Co-Authored-By: RajeshKumar11 <22585507+rajeshkumar11@users.noreply.github.com>
2026-02-18 18:21:17 -08:00
57 changed files with 1461 additions and 4452 deletions

View File

@@ -1 +1 @@
v0.4.1
v0.5.0

View File

@@ -41,11 +41,6 @@ type InferenceCompute struct {
VRAM string
}
type InferenceInfo struct {
Computes []InferenceCompute
DefaultContextLength int
}
func New(s *store.Store, devMode bool) *Server {
p := resolvePath("ollama")
return &Server{store: s, bin: p, dev: devMode}
@@ -277,12 +272,9 @@ func openRotatingLog() (io.WriteCloser, error) {
// Attempt to retrieve inference compute information from the server
// log. Set ctx to timeout to control how long to wait for the logs to appear
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
info := &InferenceInfo{}
computeMarker := regexp.MustCompile(`inference compute.*library=`)
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
inference := []InferenceCompute{}
marker := regexp.MustCompile(`inference compute.*library=`)
q := `inference compute.*%s=["]([^"]*)["]`
nq := `inference compute.*%s=(\S+)\s`
type regex struct {
@@ -348,8 +340,8 @@ func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
// Check for inference compute lines
if computeMarker.MatchString(line) {
match := marker.FindStringSubmatch(line)
if len(match) > 0 {
ic := InferenceCompute{
Library: get("library", line),
Variant: get("variant", line),
@@ -360,25 +352,12 @@ func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
}
slog.Info("Matched", "inference compute", ic)
info.Computes = append(info.Computes, ic)
continue
}
// Check for default context length line
if defaultCtxMarker.MatchString(line) {
match := defaultCtxRegex.FindStringSubmatch(line)
if len(match) > 1 {
numCtx, err := strconv.Atoi(match[1])
if err == nil {
info.DefaultContextLength = numCtx
slog.Info("Matched default context length", "default_num_ctx", numCtx)
}
inference = append(inference, ic)
} else {
// Break out on first non matching line after we start matching
if len(inference) > 0 {
return inference, nil
}
return info, nil
}
// If we've found compute info but hit a non-matching line, return what we have
// This handles older server versions that don't log the default context line
if len(info.Computes) > 0 {
return info, nil
}
}
time.Sleep(100 * time.Millisecond)

View File

@@ -205,50 +205,44 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
}
}
func TestGetInferenceInfo(t *testing.T) {
func TestGetInferenceComputer(t *testing.T) {
tests := []struct {
name string
log string
expComputes []InferenceCompute
expDefaultCtxLen int
name string
log string
exp []InferenceCompute
}{
{
name: "metal",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
expComputes: []InferenceCompute{{
exp: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 262144,
},
{
name: "cpu",
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
expComputes: []InferenceCompute{{
exp: []InferenceCompute{{
Library: "cpu",
Driver: "0.0",
VRAM: "31.3 GiB",
}},
expDefaultCtxLen: 32768,
},
{
name: "cuda1",
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
releasing cuda driver library
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
expComputes: []InferenceCompute{{
exp: []InferenceCompute{{
Library: "cuda",
Variant: "v12",
Compute: "6.1",
@@ -256,7 +250,6 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
Name: "NVIDIA GeForce GT 1030",
VRAM: "3.9 GiB",
}},
expDefaultCtxLen: 4096,
},
{
name: "frank",
@@ -264,10 +257,9 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
releasing cuda driver library
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
`,
expComputes: []InferenceCompute{
exp: []InferenceCompute{
{
Library: "cuda",
Variant: "v12",
@@ -284,20 +276,6 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
VRAM: "16.0 GiB",
},
},
expDefaultCtxLen: 32768,
},
{
name: "missing_default_context",
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
`,
expComputes: []InferenceCompute{{
Library: "metal",
Driver: "0.0",
VRAM: "96.0 GiB",
}},
expDefaultCtxLen: 0, // No default context line, should return 0
},
}
for _, tt := range tests {
@@ -310,21 +288,18 @@ time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not f
}
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
info, err := GetInferenceInfo(ctx)
ics, err := GetInferenceComputer(ctx)
if err != nil {
t.Fatalf("failed to get inference info: %v", err)
t.Fatalf(" failed to get inference compute: %v", err)
}
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
}
if info.DefaultContextLength != tt.expDefaultCtxLen {
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
if !reflect.DeepEqual(ics, tt.exp) {
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
}
})
}
}
func TestGetInferenceInfoTimeout(t *testing.T) {
func TestGetInferenceComputerTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
defer cancel()
tmpDir := t.TempDir()
@@ -333,7 +308,7 @@ func TestGetInferenceInfoTimeout(t *testing.T) {
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}
_, err = GetInferenceInfo(ctx)
_, err = GetInferenceComputer(ctx)
if err == nil {
t.Fatal("expected timeout")
}

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 14
const currentSchemaVersion = 13
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -73,7 +73,7 @@ func (db *database) init() error {
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 0,
context_length INTEGER NOT NULL DEFAULT 4096,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,
@@ -251,12 +251,6 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
case 13:
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
if err := db.migrateV13ToV14(); err != nil {
return fmt.Errorf("migrate v13 to v14: %w", err)
}
version = 14
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -480,22 +474,6 @@ func (db *database) migrateV12ToV13() error {
return nil
}
// migrateV13ToV14 changes the default context_length from 4096 to 0.
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
func (db *database) migrateV13ToV14() error {
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
if err != nil {
return fmt.Errorf("update context_length default: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`

View File

@@ -98,43 +98,6 @@ func TestSchemaMigrations(t *testing.T) {
})
}
func TestMigrationV13ToV14ContextLength(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
defer db.Close()
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
if err != nil {
t.Fatalf("failed to seed v13 settings row: %v", err)
}
if err := db.migrate(); err != nil {
t.Fatalf("migration from v13 to v14 failed: %v", err)
}
var contextLength int
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
t.Fatalf("failed to read context_length: %v", err)
}
if contextLength != 0 {
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
}
version, err := db.getSchemaVersion()
if err != nil {
t.Fatalf("failed to get schema version: %v", err)
}
if version != currentSchemaVersion {
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
}
}
func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir()

View File

@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
agent BOOLEAN NOT NULL DEFAULT 0,
tools BOOLEAN NOT NULL DEFAULT 0,
working_dir TEXT NOT NULL DEFAULT '',
context_length INTEGER NOT NULL DEFAULT 0,
context_length INTEGER NOT NULL DEFAULT 4096,
window_width INTEGER NOT NULL DEFAULT 0,
window_height INTEGER NOT NULL DEFAULT 0,
config_migrated BOOLEAN NOT NULL DEFAULT 0,

View File

@@ -289,12 +289,10 @@ export class InferenceCompute {
}
export class InferenceComputeResponse {
inferenceComputes: InferenceCompute[];
defaultContextLength: number;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
this.defaultContextLength = source["defaultContextLength"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

View File

@@ -4,6 +4,7 @@ import {
ChatEvent,
DownloadEvent,
ErrorEvent,
InferenceCompute,
InferenceComputeResponse,
ModelCapabilitiesResponse,
Model,
@@ -406,7 +407,7 @@ export async function* pullModel(
}
}
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
if (!response.ok) {
throw new Error(
@@ -415,7 +416,8 @@ export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
}
const data = await response.json();
return new InferenceComputeResponse(data);
const inferenceComputeResponse = new InferenceComputeResponse(data);
return inferenceComputeResponse.inferenceComputes || [];
}
export async function fetchHealth(): Promise<boolean> {

View File

@@ -26,7 +26,6 @@ import {
type CloudStatusResponse,
updateCloudSetting,
updateSettings,
getInferenceCompute,
} from "@/api";
function AnimatedDots() {
@@ -78,13 +77,6 @@ export default function Settings() {
const settings = settingsData?.settings || null;
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
queryFn: getInferenceCompute,
});
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
const updateSettingsMutation = useMutation({
mutationFn: updateSettings,
onSuccess: () => {
@@ -212,7 +204,7 @@ export default function Settings() {
Models: "",
Agent: false,
Tools: false,
ContextLength: 0,
ContextLength: 4096,
});
updateSettingsMutation.mutate(defaultSettings);
}
@@ -515,11 +507,13 @@ export default function Settings() {
</Description>
<div className="mt-3">
<Slider
value={settings.ContextLength || defaultContextLength || 0}
value={(() => {
// Otherwise use the settings value
return settings.ContextLength || 4096;
})()}
onChange={(value) => {
handleChange("ContextLength", value);
}}
disabled={!defaultContextLength}
options={[
{ value: 4096, label: "4k" },
{ value: 8192, label: "8k" },

View File

@@ -6,11 +6,10 @@ export interface SliderProps {
value?: number;
onChange?: (value: number) => void;
className?: string;
disabled?: boolean;
}
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
({ label, options, value = 0, onChange, disabled = false }, ref) => {
({ label, options, value = 0, onChange }, ref) => {
const [selectedValue, setSelectedValue] = React.useState(value);
const [isDragging, setIsDragging] = React.useState(false);
const containerRef = React.useRef<HTMLDivElement>(null);
@@ -21,7 +20,6 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}, [value]);
const handleClick = (optionValue: number) => {
if (disabled) return;
setSelectedValue(optionValue);
onChange?.(optionValue);
};
@@ -41,7 +39,6 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
};
const handleMouseDown = (e: React.MouseEvent) => {
if (disabled) return;
setIsDragging(true);
e.preventDefault();
};
@@ -80,7 +77,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
}
return (
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
<div className="space-y-2" ref={ref}>
{label && <label className="text-sm font-medium">{label}</label>}
<div className="relative">
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
@@ -91,11 +88,10 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
<button
onClick={() => handleClick(option.value)}
onMouseDown={handleMouseDown}
disabled={disabled}
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
>
<div className="relative w-5 h-5 flex items-center justify-center">
{selectedValue === option.value && !disabled && (
{selectedValue === option.value && (
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
)}
</div>

View File

@@ -28,14 +28,12 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
currentChatId && currentChatId !== "new" ? currentChatId : "",
);
const { data: inferenceComputeResponse } = useQuery({
queryKey: ["inferenceCompute"],
const { data: inferenceComputes = [] } = useQuery({
queryKey: ["inference-compute"],
queryFn: getInferenceCompute,
enabled: !settings.selectedModel, // Only fetch if no model is selected
});
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
const totalVRAM = useMemo(
() => getTotalVRAM(inferenceComputes),
[inferenceComputes],

View File

@@ -45,8 +45,7 @@ type InferenceCompute struct {
}
type InferenceComputeResponse struct {
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
DefaultContextLength int `json:"defaultContextLength"`
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
}
type ModelCapabilitiesResponse struct {

View File

@@ -1420,6 +1420,11 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
settings.Models = envconfig.Models()
}
// set default context length if not set
if settings.ContextLength == 0 {
settings.ContextLength = 4096
}
// Include current runtime settings
settings.Agent = s.Agent
settings.Tools = s.Tools
@@ -1495,14 +1500,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
info, err := server.GetInferenceInfo(ctx)
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
if err != nil {
s.log().Error("failed to get inference info", "error", err)
return fmt.Errorf("failed to get inference info: %w", err)
s.log().Error("failed to get inference compute", "error", err)
return fmt.Errorf("failed to get inference compute: %w", err)
}
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
for i, ic := range info.Computes {
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
for i, ic := range serverInferenceComputes {
inferenceComputes[i] = responses.InferenceCompute{
Library: ic.Library,
Variant: ic.Variant,
@@ -1514,8 +1519,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
}
response := responses.InferenceComputeResponse{
InferenceComputes: inferenceComputes,
DefaultContextLength: info.DefaultContextLength,
InferenceComputes: inferenceComputes,
}
w.Header().Set("Content-Type", "application/json")

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
// SignRequest adds a nonce query parameter and an Authorization header with
// an Ed25519 signature to req.
func SignRequest(ctx context.Context, req *http.Request) error {
nonce, err := NewNonce(rand.Reader, 16)
if err != nil {
return err
}
q := req.URL.Query()
q.Set("nonce", nonce)
req.URL.RawQuery = q.Encode()
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
signature, err := Sign(ctx, data)
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}

View File

@@ -1900,6 +1900,21 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
if version.HasCachedUpdate() {
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
_ = version.ClearCachedUpdate()
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if available, err := version.CheckForUpdate(ctx); err == nil && available {
_ = version.CacheAvailableUpdate()
}
}()
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
@@ -1956,10 +1971,6 @@ func runInteractiveTUI(cmd *cobra.Command) {
}
launchIntegration := func(name string) bool {
if err := config.EnsureInstalled(name); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
return true
}
// If not configured or model no longer exists, prompt for model selection
configuredModel := config.IntegrationModel(name)
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
@@ -2321,6 +2332,18 @@ func NewCLI() *cobra.Command {
}
}
updateCmd := &cobra.Command{
Use: "update",
Short: "Update Ollama to the latest version",
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
force, _ := cmd.Flags().GetBool("force")
_ = version.ClearCachedUpdate()
return version.DoUpdate(force)
},
}
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
rootCmd.AddCommand(
serveCmd,
createCmd,
@@ -2338,6 +2361,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
updateCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)

View File

@@ -6,7 +6,6 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -33,10 +32,6 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -15,9 +15,8 @@ import (
)
type integration struct {
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
Onboarded bool `json:"onboarded,omitempty"`
Models []string `json:"models"`
Aliases map[string]string `json:"aliases,omitempty"`
}
type config struct {
@@ -140,54 +139,34 @@ func SaveIntegration(appName string, models []string) error {
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
var aliases map[string]string
var onboarded bool
if existing != nil {
if existing != nil && existing.Aliases != nil {
aliases = existing.Aliases
onboarded = existing.Onboarded
}
cfg.Integrations[key] = &integration{
Models: models,
Aliases: aliases,
Onboarded: onboarded,
Models: models,
Aliases: aliases,
}
return save(cfg)
}
// integrationOnboarded marks an integration as onboarded in ollama's config.
func integrationOnboarded(appName string) error {
cfg, err := load()
if err != nil {
return err
}
key := strings.ToLower(appName)
existing := cfg.Integrations[key]
if existing == nil {
existing = &integration{}
}
existing.Onboarded = true
cfg.Integrations[key] = existing
return save(cfg)
}
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
func IntegrationModel(appName string) string {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
return ""
}
return integrationConfig.Models[0]
return ic.Models[0]
}
// IntegrationModels returns all configured models for an integration, or nil.
func IntegrationModels(appName string) []string {
integrationConfig, err := loadIntegration(appName)
if err != nil || len(integrationConfig.Models) == 0 {
ic, err := loadIntegration(appName)
if err != nil || len(ic.Models) == 0 {
return nil
}
return integrationConfig.Models
return ic.Models
}
// LastModel returns the last model that was run, or empty string if none.
@@ -255,12 +234,12 @@ func loadIntegration(appName string) (*integration, error) {
return nil, err
}
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
ic, ok := cfg.Integrations[strings.ToLower(appName)]
if !ok {
return nil, os.ErrNotExist
}
return integrationConfig, nil
return ic, nil
}
func saveAliases(appName string, aliases map[string]string) error {
@@ -293,8 +272,8 @@ func listIntegrations() ([]integration, error) {
}
result := make([]integration, 0, len(cfg.Integrations))
for _, integrationConfig := range cfg.Integrations {
result = append(result, *integrationConfig)
for _, ic := range cfg.Integrations {
result = append(result, *ic)
}
return result, nil

View File

@@ -228,31 +228,6 @@ func IsIntegrationInstalled(name string) bool {
}
}
// AutoInstallable returns true if the integration can be automatically
// installed when not found (e.g. via npm).
func AutoInstallable(name string) bool {
switch strings.ToLower(name) {
case "openclaw", "clawdbot", "moltbot":
return true
default:
return false
}
}
// EnsureInstalled checks if an auto-installable integration is present and
// offers to install it if missing. Returns nil for non-auto-installable
// integrations or when the binary is already on PATH.
func EnsureInstalled(name string) error {
if !AutoInstallable(name) {
return nil
}
if IsIntegrationInstalled(name) {
return nil
}
_, err := ensureOpenclawInstalled()
return err
}
// IsEditorIntegration returns true if the named integration uses multi-model
// selection (implements the Editor interface).
func IsEditorIntegration(name string) bool {
@@ -951,10 +926,6 @@ Examples:
return fmt.Errorf("unknown integration: %s", name)
}
if err := EnsureInstalled(name); err != nil {
return err
}
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
modelFlag = ""
}

View File

@@ -1,287 +1,81 @@
package config
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
const defaultGatewayPort = 18789
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
var openclawModelShowTimeout = 5 * time.Second
type Openclaw struct{}
func (c *Openclaw) String() string { return "OpenClaw" }
func (c *Openclaw) Run(model string, args []string) error {
bin, err := ensureOpenclawInstalled()
bin := "openclaw"
if _, err := exec.LookPath(bin); err != nil {
bin = "clawdbot"
if _, err := exec.LookPath(bin); err != nil {
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
}
}
models := []string{model}
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
models = config.Models
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
return selectModels(context.Background(), "openclaw", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
firstLaunch := true
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
firstLaunch = !integrationConfig.Onboarded
}
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
ok, err := confirmPrompt("I understand the risks. Continue?")
if err != nil {
return err
}
if !ok {
return nil
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if !c.onboarded() {
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
// Onboarding not completed: run it (model already set via Edit)
// Use "ollama" as gateway token for simple local access
cmd := exec.Command(bin, "onboard",
"--non-interactive",
"--accept-risk",
"--auth-choice", "skip",
"--gateway-token", "ollama",
"--install-daemon",
"--skip-channels",
"--skip-skills",
)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
}
patchDeviceScopes()
// Onboarding overwrites openclaw.json, so re-apply the model config
// that Edit() wrote before Run() was called.
if err := c.Edit([]string{model}); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
}
return cmd.Run()
}
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
if ensureWebSearchPlugin() {
registerWebSearchPlugin()
}
}
// Onboarding completed: run gateway
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
cmd.Stdin = os.Stdin
if firstLaunch {
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
}
// Capture output to detect "already running" message
var outputBuf bytes.Buffer
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
// When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow.
if len(args) > 0 {
cmd := exec.Command(bin, args...)
cmd.Env = openclawEnv()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
err = cmd.Run()
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
return nil
}
token, port := c.gatewayInfo()
addr := fmt.Sprintf("localhost:%d", port)
// If the gateway is already running (e.g. via the daemon), restart it
// so it picks up any config changes from Edit() above (model, provider, etc.).
if portOpen(addr) {
restart := exec.Command(bin, "daemon", "restart")
restart.Env = openclawEnv()
if err := restart.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
}
if !waitForPort(addr, 10*time.Second) {
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
}
}
// If the gateway isn't running, start it as a background child process.
if !portOpen(addr) {
gw := exec.Command(bin, "gateway", "run", "--force")
gw.Env = openclawEnv()
if err := gw.Start(); err != nil {
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
}
defer func() {
if gw.Process != nil {
_ = gw.Process.Kill()
_ = gw.Wait()
}
}()
}
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
if !waitForPort(addr, 30*time.Second) {
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
}
printOpenclawReady(bin, token, port, firstLaunch)
tuiArgs := []string{"tui"}
if firstLaunch {
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
}
tui := exec.Command(bin, tuiArgs...)
tui.Env = openclawEnv()
tui.Stdin = os.Stdin
tui.Stdout = os.Stdout
tui.Stderr = os.Stderr
if err := tui.Run(); err != nil {
return windowsHint(err)
}
if firstLaunch {
if err := integrationOnboarded("openclaw"); err != nil {
return fmt.Errorf("failed to save onboarding state: %w", err)
}
}
return nil
}
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
func (c *Openclaw) gatewayInfo() (token string, port int) {
port = defaultGatewayPort
home, err := os.UserHomeDir()
if err != nil {
return "", port
}
for _, path := range []string{
filepath.Join(home, ".openclaw", "openclaw.json"),
filepath.Join(home, ".clawdbot", "clawdbot.json"),
} {
data, err := os.ReadFile(path)
if err != nil {
continue
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
continue
}
gw, _ := config["gateway"].(map[string]any)
if p, ok := gw["port"].(float64); ok && p > 0 {
port = int(p)
}
auth, _ := gw["auth"].(map[string]any)
if t, _ := auth["token"].(string); t != "" {
token = t
}
return token, port
}
return "", port
}
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
u := fmt.Sprintf("http://localhost:%d", port)
if token != "" {
u += "/#token=" + url.QueryEscape(token)
}
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
if firstLaunch {
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
} else {
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
}
}
// openclawEnv returns the current environment with provider API keys cleared
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
func openclawEnv() []string {
clear := map[string]bool{
"ANTHROPIC_API_KEY": true,
"ANTHROPIC_OAUTH_TOKEN": true,
"OPENAI_API_KEY": true,
"GEMINI_API_KEY": true,
"MISTRAL_API_KEY": true,
"GROQ_API_KEY": true,
"XAI_API_KEY": true,
"OPENROUTER_API_KEY": true,
}
var env []string
for _, e := range os.Environ() {
key, _, _ := strings.Cut(e, "=")
if !clear[key] {
env = append(env, e)
}
}
return env
}
// portOpen checks if a TCP port is currently accepting connections.
func portOpen(addr string) bool {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err != nil {
return false
}
conn.Close()
return true
}
func waitForPort(addr string, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
if err == nil {
conn.Close()
return true
}
time.Sleep(250 * time.Millisecond)
}
return false
}
func windowsHint(err error) error {
if runtime.GOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\n"+
"OpenClaw runs best on WSL2.\n"+
"Quick setup: wsl --install\n"+
"Guide: https://docs.openclaw.ai/windows", err)
return err
}
// onboarded checks if OpenClaw onboarding wizard was completed
@@ -313,144 +107,6 @@ func (c *Openclaw) onboarded() bool {
return lastRunAt != ""
}
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
// operator.admin. Only patches the local device, not remote ones.
// Best-effort: silently returns on any error.
func patchDeviceScopes() {
home, err := os.UserHomeDir()
if err != nil {
return
}
deviceID := readLocalDeviceID(home)
if deviceID == "" {
return
}
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var devices map[string]map[string]any
if err := json.Unmarshal(data, &devices); err != nil {
return
}
dev, ok := devices[deviceID]
if !ok {
return
}
required := []string{
"operator.read",
"operator.admin",
"operator.approvals",
"operator.pairing",
}
changed := patchScopes(dev, "scopes", required)
if tokens, ok := dev["tokens"].(map[string]any); ok {
for _, tok := range tokens {
if tokenMap, ok := tok.(map[string]any); ok {
if patchScopes(tokenMap, "scopes", required) {
changed = true
}
}
}
}
if !changed {
return
}
out, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
// readLocalDeviceID reads the local device ID from openclaw's identity file.
func readLocalDeviceID(home string) string {
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
if err != nil {
return ""
}
var auth map[string]any
if err := json.Unmarshal(data, &auth); err != nil {
return ""
}
id, _ := auth["deviceId"].(string)
return id
}
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
existing, _ := obj[key].([]any)
have := make(map[string]bool, len(existing))
for _, s := range existing {
if str, ok := s.(string); ok {
have[str] = true
}
}
added := false
for _, s := range required {
if !have[s] {
existing = append(existing, s)
added = true
}
}
if added {
obj[key] = existing
}
return added
}
func ensureOpenclawInstalled() (string, error) {
if _, err := exec.LookPath("openclaw"); err == nil {
return "openclaw", nil
}
if _, err := exec.LookPath("clawdbot"); err == nil {
return "clawdbot", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
"Install Node.js first:\n" +
" https://nodejs.org/\n\n" +
"Then rerun:\n" +
" ollama launch\n" +
"and select OpenClaw")
}
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("openclaw installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install openclaw: %w", err)
}
if _, err := exec.LookPath("openclaw"); err != nil {
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
return "openclaw", nil
}
func (c *Openclaw) Paths() []string {
home, _ := os.UserHomeDir()
p := filepath.Join(home, ".openclaw", "openclaw.json")
@@ -505,7 +161,8 @@ func (c *Openclaw) Edit(models []string) error {
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
// needed to register provider
ollama["apiKey"] = "ollama-local"
ollama["api"] = "ollama"
// TODO(parthsareen): potentially move to responses
ollama["api"] = "openai-completions"
// Build map of existing models to preserve user customizations
existingModels, _ := ollama["models"].([]any)
@@ -518,13 +175,25 @@ func (c *Openclaw) Edit(models []string) error {
}
}
client, _ := api.ClientFromEnvironment()
var newModels []any
for _, m := range models {
entry, _ := openclawModelConfig(context.Background(), client, m)
for _, model := range models {
entry := map[string]any{
"id": model,
"name": model,
"reasoning": false,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
// TODO(parthsareen): get these values from API
"contextWindow": 131072,
"maxTokens": 16384,
}
// Merge existing fields (user customizations)
if existing, ok := existingByID[m]; ok {
if existing, ok := existingByID[model]; ok {
for k, v := range existing {
if _, isNew := entry[k]; !isNew {
entry[k] = v
@@ -561,213 +230,7 @@ func (c *Openclaw) Edit(models []string) error {
if err != nil {
return err
}
if err := writeWithBackup(configPath, data); err != nil {
return err
}
// Clear any per-session model overrides so the new primary takes effect
// immediately rather than being shadowed by a cached modelOverride.
clearSessionModelOverride(models[0])
return nil
}
// clearSessionModelOverride removes per-session model overrides from the main
// agent session so the global primary model takes effect on the next TUI launch.
func clearSessionModelOverride(primary string) {
home, err := os.UserHomeDir()
if err != nil {
return
}
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
data, err := os.ReadFile(path)
if err != nil {
return
}
var sessions map[string]map[string]any
if json.Unmarshal(data, &sessions) != nil {
return
}
changed := false
for _, sess := range sessions {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
sess["model"] = primary
changed = true
}
}
if !changed {
return
}
out, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return
}
_ = os.WriteFile(path, out, 0o600)
}
const webSearchNpmPackage = "@ollama/openclaw-web-search"
// ensureWebSearchPlugin installs the openclaw-web-search extension into the
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already
// present. Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
return true // already installed
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
return true
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir()
if err != nil {
return
}
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
data, err := os.ReadFile(configPath)
if err != nil {
return
}
var config map[string]any
if json.Unmarshal(data, &config) != nil {
return
}
plugins, _ := config["plugins"].(map[string]any)
if plugins == nil {
plugins = make(map[string]any)
}
entries, _ := plugins["entries"].(map[string]any)
if entries == nil {
entries = make(map[string]any)
}
if _, ok := entries["openclaw-web-search"]; ok {
return // already registered
}
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
config["plugins"] = plugins
// Disable the built-in web search since our plugin replaces it.
tools, _ := config["tools"].(map[string]any)
if tools == nil {
tools = make(map[string]any)
}
web, _ := tools["web"].(map[string]any)
if web == nil {
web = make(map[string]any)
}
web["search"] = map[string]any{"enabled": false}
tools["web"] = web
config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ")
if err != nil {
return
}
_ = os.WriteFile(configPath, out, 0o600)
}
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
// The second return value indicates whether the model is a cloud (remote) model.
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
entry := map[string]any{
"id": modelID,
"name": modelID,
"input": []any{"text"},
"cost": map[string]any{
"input": 0,
"output": 0,
"cacheRead": 0,
"cacheWrite": 0,
},
}
if client == nil {
return entry, false
}
showCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
defer cancel()
}
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
if err != nil {
return entry, false
}
// Set input types based on vision capability
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
entry["input"] = []any{"text", "image"}
}
// Set reasoning based on thinking capability
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
entry["reasoning"] = true
}
// Cloud models: use hardcoded limits for context/output tokens.
// Capability detection above still applies (vision, thinking).
if resp.RemoteModel != "" {
if l, ok := lookupCloudModelLimit(modelID); ok {
entry["contextWindow"] = l.Context
entry["maxTokens"] = l.Output
}
return entry, true
}
// Extract context window from ModelInfo (local models only)
for key, val := range resp.ModelInfo {
if strings.HasSuffix(key, ".context_length") {
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
entry["contextWindow"] = int(ctxLen)
}
break
}
}
return entry, false
return writeWithBackup(configPath, data)
}
func (c *Openclaw) Models() []string {

View File

@@ -1,21 +1,11 @@
package config
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestOpenclawIntegration(t *testing.T) {
@@ -36,124 +26,6 @@ func TestOpenclawIntegration(t *testing.T) {
})
}
func TestOpenclawRunPassthroughArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
bin := filepath.Join(tmpDir, "openclaw")
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 1 {
t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
}
if lines[0] != "gateway --someflag" {
t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
}
}
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses a POSIX shell test binary")
}
oldHook := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string) (bool, error) {
return true, nil
}
defer func() { DefaultConfirmPrompt = oldHook }()
t.Run("success persists onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil {
t.Fatalf("loadIntegration() error = %v", err)
}
if !integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to be persisted after successful run")
}
})
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatal(err)
}
c := &Openclaw{}
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
t.Fatal("expected run failure")
}
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Fatal("expected onboarding flag to remain unset after failed run")
}
})
}
func TestOpenclawEdit(t *testing.T) {
c := &Openclaw{}
tmpDir := t.TempDir()
@@ -487,16 +359,19 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
modelList := ollama["models"].([]any)
entry := modelList[0].(map[string]any)
// Verify base schema fields (always set regardless of API availability)
if entry["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", entry["id"])
}
if entry["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", entry["name"])
// Verify required schema fields
if entry["reasoning"] != false {
t.Error("reasoning should be false")
}
if entry["input"] == nil {
t.Error("input should be set")
}
if entry["contextWindow"] == nil {
t.Error("contextWindow should be set")
}
if entry["maxTokens"] == nil {
t.Error("maxTokens should be set")
}
cost := entry["cost"].(map[string]any)
if cost["cacheRead"] == nil {
t.Error("cost.cacheRead should be set")
@@ -1001,589 +876,3 @@ func TestOpenclawOnboarded(t *testing.T) {
}
})
}
func TestOpenclawGatewayInfo(t *testing.T) {
c := &Openclaw{}
t.Run("returns defaults when no config exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("reads token and port from config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {
"port": 9999,
"auth": {"mode": "token", "token": "my-secret"}
}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "my-secret" {
t.Errorf("expected token %q, got %q", "my-secret", token)
}
if port != 9999 {
t.Errorf("expected port 9999, got %d", port)
}
})
t.Run("uses default port when not in config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
"gateway": {"auth": {"token": "tok"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "tok" {
t.Errorf("expected token %q, got %q", "tok", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
}
})
t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
legacyDir := filepath.Join(tmpDir, ".clawdbot")
os.MkdirAll(legacyDir, 0o755)
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
}`), 0o644)
token, port := c.gatewayInfo()
if token != "legacy-token" {
t.Errorf("expected token %q, got %q", "legacy-token", token)
}
if port != 12345 {
t.Errorf("expected port 12345, got %d", port)
}
})
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
t.Run("handles missing gateway section", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
os.MkdirAll(configDir, 0o755)
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
token, port := c.gatewayInfo()
if token != "" {
t.Errorf("expected empty token, got %q", token)
}
if port != defaultGatewayPort {
t.Errorf("expected default port, got %d", port)
}
})
}
func TestPrintOpenclawReady(t *testing.T) {
t.Run("includes port in URL", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", 9999, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "localhost:9999") {
t.Errorf("expected port 9999 in output, got:\n%s", output)
}
if strings.Contains(output, "#token=") {
t.Error("should not include token fragment when token is empty")
}
})
t.Run("URL-escapes token", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
escaped := url.QueryEscape("my token&special=chars")
if !strings.Contains(output, "#token="+escaped) {
t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
}
})
t.Run("simple token is not mangled", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "#token=ollama") {
t.Errorf("expected #token=ollama in output, got:\n%s", output)
}
})
t.Run("includes web UI hint", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Open the Web UI") {
t.Errorf("expected web UI hint in output, got:\n%s", output)
}
})
t.Run("first launch shows quick start tips", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
for _, want := range []string{"/help", "channels", "skills", "gateway"} {
if !strings.Contains(output, want) {
t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
}
}
})
t.Run("subsequent launch shows single tip", func(t *testing.T) {
var buf bytes.Buffer
old := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
w.Close()
os.Stderr = old
buf.ReadFrom(r)
output := buf.String()
if !strings.Contains(output, "Tip:") {
t.Errorf("expected single tip line, got:\n%s", output)
}
if strings.Contains(output, "Quick start") {
t.Errorf("should not show quick start on subsequent launch")
}
})
}
func TestOpenclawModelConfig(t *testing.T) {
t.Run("nil client returns base config", func(t *testing.T) {
cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")
if cfg["id"] != "llama3.2" {
t.Errorf("id = %v, want llama3.2", cfg["id"])
}
if cfg["name"] != "llama3.2" {
t.Errorf("name = %v, want llama3.2", cfg["name"])
}
if cfg["cost"] == nil {
t.Error("cost should be set")
}
// Should not have capability fields without API
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set without API")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set without API")
}
})
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
})
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 1 {
t.Errorf("input = %v, want [text]", cfg["input"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set for non-thinking model")
}
})
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwq")
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for thinking model")
}
})
t.Run("extracts context window from model info", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
if cfg["contextWindow"] != 131072 {
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
}
})
t.Run("handles all capabilities together", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image]", cfg["input"])
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true")
}
if cfg["contextWindow"] != 32768 {
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
}
})
t.Run("returns base config when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")
if cfg["id"] != "missing-model" {
t.Errorf("id = %v, want missing-model", cfg["id"])
}
// Should still have input (default)
if cfg["input"] == nil {
t.Error("input should always be set")
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set when show fails")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set when show fails")
}
})
t.Run("times out slow show and returns base config", func(t *testing.T) {
oldTimeout := openclawModelShowTimeout
openclawModelShowTimeout = 50 * time.Millisecond
t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
time.Sleep(300 * time.Millisecond)
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
start := time.Now()
cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
elapsed := time.Since(start)
if elapsed >= 250*time.Millisecond {
t.Fatalf("openclawModelConfig took too long: %v", elapsed)
}
if cfg["id"] != "slow-model" {
t.Errorf("id = %v, want slow-model", cfg["id"])
}
if _, ok := cfg["reasoning"]; ok {
t.Error("reasoning should not be set on timeout")
}
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set on timeout")
}
})
t.Run("skips zero context length", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, _ := openclawModelConfig(context.Background(), client, "test-model")
if _, ok := cfg["contextWindow"]; ok {
t.Error("contextWindow should not be set for zero value")
}
})
t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
// Use a model name that's in cloudModelLimits and make the server
// report it as a remote/cloud model
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud model")
}
if cfg["contextWindow"] != 204_800 {
t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
}
if cfg["maxTokens"] != 128_000 {
t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
}
})
t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports vision capability should have input: [text, image].
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud vision model")
}
input, ok := cfg["input"].([]any)
if !ok || len(input) != 2 {
t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
}
})
t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
// Regression test: cloud models must not skip capability detection.
// A cloud model that reports thinking capability should have reasoning: true.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud thinking model")
}
if cfg["reasoning"] != true {
t.Error("expected reasoning = true for cloud thinking model")
}
})
}
func TestIntegrationOnboarded(t *testing.T) {
t.Run("returns false when not set", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
integrationConfig, err := loadIntegration("openclaw")
if err == nil && integrationConfig.Onboarded {
t.Error("expected false for fresh config")
}
})
t.Run("returns true after integrationOnboarded", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
})
t.Run("is case insensitive", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := integrationOnboarded("OpenClaw"); err != nil {
t.Fatal(err)
}
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true when set with different case")
}
})
t.Run("preserves existing integration data", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
t.Fatal(err)
}
if err := integrationOnboarded("openclaw"); err != nil {
t.Fatal(err)
}
// Verify onboarded is set
integrationConfig, err := loadIntegration("openclaw")
if err != nil || !integrationConfig.Onboarded {
t.Error("expected true after integrationOnboarded")
}
// Verify models are preserved
model := IntegrationModel("openclaw")
if model != "llama3.2" {
t.Errorf("expected first model llama3.2, got %q", model)
}
})
}

View File

@@ -10,11 +10,10 @@ import (
// ANSI escape sequences for terminal formatting.
const (
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
ansiYellow = "\033[33m"
ansiBold = "\033[1m"
ansiReset = "\033[0m"
ansiGray = "\033[37m"
ansiGreen = "\033[32m"
)
// ErrCancelled is returned when the user cancels a selection.

View File

@@ -524,7 +524,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter", " ":
item := m.items[m.cursor]
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
return m, nil
}
@@ -555,12 +555,6 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
item := m.items[m.cursor]
if item.integration != "" || item.isRunModel {
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
if config.AutoInstallable(item.integration) {
// Auto-installable: select to trigger install flow
m.selected = true
m.quitting = true
return m, tea.Quit
}
return m, nil
}
if item.integration != "" && config.IsEditorIntegration(item.integration) {
@@ -624,11 +618,7 @@ func (m model) View() string {
var modelSuffix string
if item.integration != "" {
if !isInstalled {
if config.AutoInstallable(item.integration) {
title += " " + notInstalledStyle.Render("(install)")
} else {
title += " " + notInstalledStyle.Render("(not installed)")
}
title += " " + notInstalledStyle.Render("(not installed)")
} else if m.cursor == i {
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
modelSuffix = " " + modelStyle.Render("("+mdl+")")
@@ -644,9 +634,7 @@ func (m model) View() string {
desc := item.description
if !isInstalled && item.integration != "" && m.cursor == i {
if config.AutoInstallable(item.integration) {
desc = "Press enter to install"
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
desc = hint
} else {
desc = "not installed"

View File

@@ -4,65 +4,47 @@ title: OpenClaw
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
## Quick start
## Install
Install [OpenClaw](https://openclaw.ai/)
```bash
npm install -g openclaw@latest
```
Then run the onboarding wizard:
```bash
openclaw onboard --install-daemon
```
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
## Usage with Ollama
### Quick setup
```bash
ollama launch openclaw
```
Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
## Configure without launching
This configures OpenClaw to use Ollama and starts the gateway.
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
To change the model without starting the gateway and TUI:
```bash
To configure without launching:
```shell
ollama launch openclaw --config
```
To use a specific model directly:
## Recommended Models
```bash
ollama launch openclaw --model kimi-k2.5:cloud
```
If the gateway is already running, it restarts automatically to pick up the new model.
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
- `glm-5:cloud` — Reasoning and code generation
**Local models:**
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
```bash
openclaw configure --section channels
```
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
## Stopping the gateway
```bash
openclaw gateway stop
```
- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).

1
go.mod
View File

@@ -26,7 +26,6 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/klauspost/compress v1.18.3
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -122,6 +122,7 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -149,9 +150,8 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -497,17 +496,6 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
return
}
defer reader.Close()
c.Request.Body = io.NopCloser(reader)
c.Request.Header.Del("Content-Encoding")
}
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))

View File

@@ -14,7 +14,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1239,102 +1238,3 @@ func TestImageEditsMiddleware(t *testing.T) {
})
}
}
func zstdCompress(t *testing.T, data []byte) []byte {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(data); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return buf.Bytes()
}
func TestResponsesMiddlewareZstd(t *testing.T) {
tests := []struct {
name string
body string
useZstd bool
oversized bool
wantCode int
wantModel string
wantMessage string
}{
{
name: "plain JSON",
body: `{"model": "test-model", "input": "Hello"}`,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd compressed",
body: `{"model": "test-model", "input": "Hello"}`,
useZstd: true,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd over max decompressed size",
oversized: true,
useZstd: true,
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedRequest *api.ChatRequest
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
c.Status(http.StatusOK)
})
var bodyReader io.Reader
if tt.oversized {
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
} else if tt.useZstd {
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
} else {
bodyReader = strings.NewReader(tt.body)
}
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
req.Header.Set("Content-Type", "application/json")
if tt.useZstd || tt.oversized {
req.Header.Set("Content-Encoding", "zstd")
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.wantCode {
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
}
if tt.wantCode != http.StatusOK {
return
}
if capturedRequest == nil {
t.Fatal("expected captured request, got nil")
}
if capturedRequest.Model != tt.wantModel {
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
}
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
}
})
}
}

View File

@@ -2,10 +2,6 @@
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
# Wrap script in main function so that a truncated partial download doesn't end
# up executing half a script.
main() {
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
@@ -450,6 +446,3 @@ fi
status "NVIDIA GPU ready."
install_success
}
main

190
version/update.go Normal file
View File

@@ -0,0 +1,190 @@
package version
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
)
var updateCheckURLBase = "https://ollama.com"
// CheckForUpdate calls the ollama.com update API and reports whether a
// newer version is available.
func CheckForUpdate(ctx context.Context) (bool, error) {
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
if err != nil {
return false, fmt.Errorf("parse update URL: %w", err)
}
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", Version)
requestURL.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return false, fmt.Errorf("create request: %w", err)
}
_ = auth.SignRequest(ctx, req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("update check request: %w", err)
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK, nil
}
func cacheFilePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "update"), nil
}
// CacheAvailableUpdate creates the update marker file.
func CacheAvailableUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
return f.Close()
}
// HasCachedUpdate reports whether a non-stale update marker exists.
func HasCachedUpdate() bool {
path, err := cacheFilePath()
if err != nil {
return false
}
fi, err := os.Stat(path)
if err != nil {
return false
}
return time.Since(fi.ModTime()) <= 24*time.Hour
}
// ClearCachedUpdate removes the update marker file.
func ClearCachedUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
err = os.Remove(path)
if os.IsNotExist(err) {
return nil
}
return err
}
func IsOfficialInstall() bool {
exe, err := os.Executable()
if err != nil {
return false
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return false
}
switch runtime.GOOS {
case "windows":
localAppData := os.Getenv("LOCALAPPDATA")
if localAppData == "" {
return false
}
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
case "darwin":
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
default:
dir := filepath.Dir(exe)
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
}
}
// DoUpdate downloads and runs the platform-appropriate install script.
func DoUpdate(force bool) error {
if !force && !IsOfficialInstall() {
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
}
var scriptURL, tmpPattern, shell string
switch runtime.GOOS {
case "windows":
scriptURL = "https://ollama.com/install.ps1"
tmpPattern = "ollama-install-*.ps1"
shell = "powershell"
default:
scriptURL = "https://ollama.com/install.sh"
tmpPattern = "ollama-install-*.sh"
shell = "sh"
}
resp, err := http.Get(scriptURL)
if err != nil {
return fmt.Errorf("download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download install script: status %d", resp.StatusCode)
}
tmpFile, err := os.CreateTemp("", tmpPattern)
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
tmpFile.Close()
return fmt.Errorf("write install script: %w", err)
}
tmpFile.Close()
cmd := exec.Command(shell, tmpFile.Name())
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// IsLocalHost reports whether the configured Ollama host points to the
// local machine.
func IsLocalHost(host *url.URL) bool {
hostname := host.Hostname()
switch hostname {
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
return true
}
if ip := net.ParseIP(hostname); ip != nil {
return ip.IsLoopback()
}
return false
}

146
version/update_test.go Normal file
View File

@@ -0,0 +1,146 @@
package version
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"testing"
"time"
)
func setHome(t *testing.T, dir string) {
t.Helper()
if runtime.GOOS == "windows" {
t.Setenv("USERPROFILE", dir)
} else {
t.Setenv("HOME", dir)
}
}
func TestCheckForUpdate(t *testing.T) {
t.Run("update available", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
t.Error("missing expected query parameters")
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !available {
t.Fatal("expected update to be available")
}
})
t.Run("up to date", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if available {
t.Fatal("expected no update available")
}
})
t.Run("network error", func(t *testing.T) {
old := updateCheckURLBase
updateCheckURLBase = "http://localhost:1"
defer func() { updateCheckURLBase = old }()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := CheckForUpdate(ctx)
if err == nil {
t.Fatal("expected error for unreachable server")
}
})
}
func TestCacheRoundTrip(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
if !HasCachedUpdate() {
t.Fatal("expected cached update to be present")
}
if err := ClearCachedUpdate(); err != nil {
t.Fatalf("cache clear: %v", err)
}
if HasCachedUpdate() {
t.Fatal("expected no cached update after clear")
}
}
func TestHasCachedUpdateStale(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
// Backdate the file to make it stale
path := filepath.Join(tmp, ".ollama", "update")
staleTime := time.Now().Add(-25 * time.Hour)
os.Chtimes(path, staleTime, staleTime)
if HasCachedUpdate() {
t.Fatal("expected no cached update for stale file")
}
}
func TestIsLocalHost(t *testing.T) {
tests := []struct {
host string
local bool
}{
{"http://127.0.0.1:11434", true},
{"http://localhost:11434", true},
{"http://[::1]:11434", true},
{"http://0.0.0.0:11434", true},
{"http://remote.example.com:11434", false},
{"http://192.168.1.100:11434", false},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
u, err := url.Parse(tt.host)
if err != nil {
t.Fatalf("parse URL: %v", err)
}
if got := IsLocalHost(u); got != tt.local {
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
}
})
}
}

View File

@@ -16,10 +16,10 @@ import (
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
if err != nil {
return err
}
// Private headers contain C++ implementation helpers and are not part of
// the C API surface; parsing them can produce invalid wrapper signatures.
if d.IsDir() && d.Name() == "private" {
return fs.SkipDir
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}

View File

@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
#endif
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
#endif
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
return mlx_array_new_data_ptr(data, shape, dim, dtype);
}
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
}
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
}
int mlx_array_set(mlx_array* arr, const mlx_array src) {
return mlx_array_set_ptr(arr, src);
}
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
return mlx_set_compile_mode_ptr(mode);
}
int mlx_cuda_is_available(bool* res) {
return mlx_cuda_is_available_ptr(res);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
return mlx_set_default_device_ptr(dev);
}
int mlx_device_is_available(bool* avail, mlx_device dev) {
return mlx_device_is_available_ptr(avail, dev);
}
int mlx_device_count(int* count, mlx_device_type type) {
return mlx_device_count_ptr(count, type);
}
mlx_device_info mlx_device_info_new(void) {
return mlx_device_info_new_ptr();
}
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
return mlx_device_info_get_ptr(info, dev);
}
int mlx_device_info_free(mlx_device_info info) {
return mlx_device_info_free_ptr(info);
}
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
return mlx_device_info_has_key_ptr(exists, info, key);
}
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
return mlx_device_info_is_string_ptr(is_string, info, key);
}
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
return mlx_device_info_get_string_ptr(value, info, key);
}
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
return mlx_device_info_get_size_ptr(value, info, key);
}
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
return mlx_device_info_get_keys_ptr(keys, info);
}
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
return mlx_distributed_all_gather_ptr(res, x, group, S);
}
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
return mlx_set_wired_limit_ptr(res, limit);
}
mlx_metal_device_info_t mlx_metal_device_info(void) {
return mlx_metal_device_info_ptr();
}
int mlx_metal_is_available(bool* res) {
return mlx_metal_is_available_ptr(res);
}

View File

@@ -26,6 +26,8 @@
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_new_data_managed
#undef mlx_array_new_data_managed_payload
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
@@ -121,6 +123,7 @@
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_cuda_is_available
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
@@ -131,6 +134,16 @@
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_device_is_available
#undef mlx_device_count
#undef mlx_device_info_new
#undef mlx_device_info_get
#undef mlx_device_info_free
#undef mlx_device_info_has_key
#undef mlx_device_info_is_string
#undef mlx_device_info_get_string
#undef mlx_device_info_get_size
#undef mlx_device_info_get_keys
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
@@ -261,7 +274,6 @@
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
extern mlx_array (*mlx_array_new_double_ptr)(double val);
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
#endif
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
#endif
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
extern int (*mlx_disable_compile_ptr)(void);
extern int (*mlx_enable_compile_ptr)(void);
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
extern int (*mlx_cuda_is_available_ptr)(bool* res);
extern mlx_device (*mlx_device_new_ptr)(void);
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
extern int (*mlx_device_free_ptr)(mlx_device dev);
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
extern int (*mlx_metal_is_available_ptr)(bool* res);
extern int (*mlx_metal_start_capture_ptr)(const char* path);
extern int (*mlx_metal_stop_capture_ptr)(void);
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
int mlx_array_set(mlx_array* arr, const mlx_array src);
int mlx_array_set_bool(mlx_array* arr, bool val);
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
int mlx_array_item_float64(double* res, const mlx_array arr);
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
const double* mlx_array_data_float64(const mlx_array arr);
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr);
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
int mlx_cuda_is_available(bool* res);
mlx_device mlx_device_new(void);
mlx_device mlx_device_new_type(mlx_device_type type, int index);
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
int mlx_set_default_device(mlx_device dev);
int mlx_device_is_available(bool* avail, mlx_device dev);
int mlx_device_count(int* count, mlx_device_type type);
mlx_device_info mlx_device_info_new(void);
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
int mlx_device_info_free(mlx_device_info info);
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
mlx_metal_device_info_t mlx_metal_device_info(void);
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);

View File

@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
FetchContent_Declare(
mlx-c

View File

@@ -18,9 +18,7 @@
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
if (handle->ctx == NULL) {
return 1;
}
CHECK(handle->ctx != NULL);
return 0;
}

View File

@@ -55,30 +55,6 @@ func tryLoadFromDir(dir string) bool {
return false
}
// tryLoadByName attempts to load the library using just its name,
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
// Returns true if the library was successfully loaded.
func tryLoadByName() bool {
libraryName := "libmlxc.dylib"
if runtime.GOOS == "linux" {
libraryName = "libmlxc.so"
}
cPath := C.CString(libraryName)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
return false
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
C.mlx_dynamic_unload(&handle)
return false
}
return true
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -97,11 +73,6 @@ func init() {
}
}
// Try loading via rpath/standard library search
if tryLoadByName() {
return
}
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {

View File

@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
const int* shape,
int dim,
mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*)) = NULL;
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
const mlx_vector_array input) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
const mlx_vector_array input_2) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
size_t input_2_num) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_)(void) = NULL;
int (*mlx_enable_compile_)(void) = NULL;
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_)(bool* res) = NULL;
mlx_device (*mlx_device_new_)(void) = NULL;
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_)(mlx_device dev) = NULL;
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_)(
bool* exists,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_is_string_)(
bool* is_string,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_string_)(
const char** value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_size_)(
size_t* value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
int (*mlx_metal_is_available_)(bool* res) = NULL;
int (*mlx_metal_start_capture_)(const char* path) = NULL;
int (*mlx_metal_stop_capture_)(void) = NULL;
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
const mlx_array values,
int axis,
const mlx_stream s) = NULL;
int (*mlx_qqmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array w_scales /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
const char * (*mlx_string_data_)(mlx_string str) = NULL;
int (*mlx_string_free_)(mlx_string str) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_)(
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_new_double);
CHECK_LOAD(handle, mlx_array_new_complex);
CHECK_LOAD(handle, mlx_array_new_data);
CHECK_LOAD(handle, mlx_array_new_data_managed);
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
CHECK_LOAD(handle, mlx_array_set);
CHECK_LOAD(handle, mlx_array_set_bool);
CHECK_LOAD(handle, mlx_array_set_int);
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_disable_compile);
CHECK_LOAD(handle, mlx_enable_compile);
CHECK_LOAD(handle, mlx_set_compile_mode);
CHECK_LOAD(handle, mlx_cuda_is_available);
CHECK_LOAD(handle, mlx_device_new);
CHECK_LOAD(handle, mlx_device_new_type);
CHECK_LOAD(handle, mlx_device_free);
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_device_get_type);
CHECK_LOAD(handle, mlx_get_default_device);
CHECK_LOAD(handle, mlx_set_default_device);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_device_is_available);
CHECK_LOAD(handle, mlx_device_count);
CHECK_LOAD(handle, mlx_device_info_new);
CHECK_LOAD(handle, mlx_device_info_get);
CHECK_LOAD(handle, mlx_device_info_free);
CHECK_LOAD(handle, mlx_device_info_has_key);
CHECK_LOAD(handle, mlx_device_info_is_string);
CHECK_LOAD(handle, mlx_device_info_get_string);
CHECK_LOAD(handle, mlx_device_info_get_size);
CHECK_LOAD(handle, mlx_device_info_get_keys);
CHECK_LOAD(handle, mlx_distributed_all_gather);
CHECK_LOAD(handle, mlx_distributed_all_max);
CHECK_LOAD(handle, mlx_distributed_all_min);
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_distributed_recv_like);
CHECK_LOAD(handle, mlx_distributed_send);
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_set_error_handler);
CHECK_LOAD(handle, _mlx_error);
CHECK_LOAD(handle, mlx_export_function);
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
CHECK_LOAD(handle, mlx_fast_rms_norm);
CHECK_LOAD(handle, mlx_fast_rope);
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
CHECK_LOAD(handle, mlx_fft_fft);
CHECK_LOAD(handle, mlx_fft_fft2);
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fft_rfft);
CHECK_LOAD(handle, mlx_fft_rfft2);
CHECK_LOAD(handle, mlx_fft_rfftn);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
CHECK_LOAD(handle, mlx_io_reader_free);
CHECK_LOAD(handle, mlx_io_writer_new);
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
CHECK_LOAD(handle, mlx_io_reader_free);
CHECK_LOAD(handle, mlx_io_writer_new);
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_linalg_cholesky);
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
CHECK_LOAD(handle, mlx_linalg_cross);
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_set_cache_limit);
CHECK_LOAD(handle, mlx_set_memory_limit);
CHECK_LOAD(handle, mlx_set_wired_limit);
CHECK_LOAD(handle, mlx_metal_device_info);
CHECK_LOAD(handle, mlx_metal_is_available);
CHECK_LOAD(handle, mlx_metal_start_capture);
CHECK_LOAD(handle, mlx_metal_stop_capture);
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_full);
CHECK_LOAD(handle, mlx_full_like);
CHECK_LOAD(handle, mlx_gather);
CHECK_LOAD(handle, mlx_gather_single);
CHECK_LOAD(handle, mlx_gather_mm);
CHECK_LOAD(handle, mlx_gather_qmm);
CHECK_LOAD(handle, mlx_greater);
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_prod_axis);
CHECK_LOAD(handle, mlx_prod);
CHECK_LOAD(handle, mlx_put_along_axis);
CHECK_LOAD(handle, mlx_qqmm);
CHECK_LOAD(handle, mlx_quantize);
CHECK_LOAD(handle, mlx_quantized_matmul);
CHECK_LOAD(handle, mlx_radians);
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_round);
CHECK_LOAD(handle, mlx_rsqrt);
CHECK_LOAD(handle, mlx_scatter);
CHECK_LOAD(handle, mlx_scatter_single);
CHECK_LOAD(handle, mlx_scatter_add);
CHECK_LOAD(handle, mlx_scatter_add_single);
CHECK_LOAD(handle, mlx_scatter_add_axis);
CHECK_LOAD(handle, mlx_scatter_max);
CHECK_LOAD(handle, mlx_scatter_max_single);
CHECK_LOAD(handle, mlx_scatter_min);
CHECK_LOAD(handle, mlx_scatter_min_single);
CHECK_LOAD(handle, mlx_scatter_prod);
CHECK_LOAD(handle, mlx_scatter_prod_single);
CHECK_LOAD(handle, mlx_segmented_mm);
CHECK_LOAD(handle, mlx_sigmoid);
CHECK_LOAD(handle, mlx_sign);
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_string_set);
CHECK_LOAD(handle, mlx_string_data);
CHECK_LOAD(handle, mlx_string_free);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_async_eval);
CHECK_LOAD(handle, mlx_checkpoint);
CHECK_LOAD(handle, mlx_custom_function);
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_jvp);
CHECK_LOAD(handle, mlx_value_and_grad);
CHECK_LOAD(handle, mlx_vjp);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_vector_array_new);
CHECK_LOAD(handle, mlx_vector_array_set);
CHECK_LOAD(handle, mlx_vector_array_free);

View File

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,10 @@
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}

View File

@@ -8,10 +8,10 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.

View File

@@ -7,6 +7,7 @@ import (
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@@ -125,5 +126,13 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
return ""
}
return flushValidUTF8Prefix(b)
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
}

View File

@@ -12,12 +12,12 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
type Request struct {

View File

@@ -1,47 +0,0 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}

View File

@@ -1,46 +0,0 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -9,12 +9,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {

View File

@@ -1,108 +0,0 @@
//go:build mlx
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
package tokenizer
import "regexp"
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE and SentencePiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
sortedSpecialTokens []string // Special tokens sorted by length, longest first
typ TokenizerType // Algorithm type
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}

View File

@@ -1,251 +0,0 @@
//go:build mlx
package tokenizer
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
var (
benchmarkSinkIDs []int32
benchmarkSinkStr string
benchmarkSinkTok *Tokenizer
)
const benchmarkWordPieceJSON = `{
"model": {
"type": "WordPiece",
"vocab": {
"[UNK]": 0,
"hello": 1,
"##world": 2,
"##ly": 3,
"##hello": 4
}
},
"added_tokens": []
}`
const benchmarkSentencePieceJSON = `{
"model": {
"type": "BPE",
"vocab": {
"\u2581": 0,
"h": 1,
"e": 2,
"l": 3,
"o": 4,
"w": 5,
"r": 6,
"d": 7,
"<0x0A>": 8
},
"merges": []
},
"decoder": {
"type": "Sequence",
"decoders": [
{
"type": "Replace",
"pattern": {
"String": "\u2581"
}
}
]
},
"added_tokens": []
}`
func benchmarkMiniLlamaPath(tb testing.TB) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve benchmark file path")
}
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
}
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
tb.Helper()
data := benchmarkLoadMiniLlamaBytes(tb)
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
}
return tok
}
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
tb.Helper()
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
if err != nil {
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
}
return data
}
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
tb.Helper()
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
}
return tok
}
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "short", text: "Hello, world!"},
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
}
for _, input := range inputs {
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(input.text, false)
}
})
}
}
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
tok := benchmarkLoadMiniLlama(b)
inputs := []struct {
name string
text string
}{
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
}
for _, input := range inputs {
ids := tok.Encode(input.text, false)
b.Run(input.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(input.text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
})
}
}
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
data := benchmarkLoadMiniLlamaBytes(b)
config := &TokenizerConfig{
TokenizerConfigJSON: []byte(`{
"bos_token": {"content": "<|begin_of_text|>"},
"eos_token": {"content": "<|end_of_text|>"},
"add_bos_token": true
}`),
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
}
b.Run("without_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytes(data)
if err != nil {
b.Fatalf("LoadFromBytes failed: %v", err)
}
benchmarkSinkTok = tok
}
})
b.Run("with_config", func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
tok, err := LoadFromBytesWithConfig(data, config)
if err != nil {
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
}
benchmarkSinkTok = tok
}
})
}
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
text := strings.Repeat("helloworldly", 16)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkIDs = tok.Encode(text, false)
}
}
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
text := strings.Repeat("hello world\n", 64)
ids := tok.Encode(text, false)
b.ReportAllocs()
b.SetBytes(int64(len(text)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkSinkStr = tok.Decode(ids)
}
}

View File

@@ -1,175 +0,0 @@
//go:build mlx
package tokenizer
import "container/heap"
type bpeMergeNode struct {
prev int
next int
token string
}
type bpePair struct {
left int
right int
rank int
value string
}
type bpePairHeap []*bpePair
func (h bpePairHeap) Len() int { return len(h) }
func (h bpePairHeap) Less(i, j int) bool {
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
}
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bpePairHeap) Push(x any) {
*h = append(*h, x.(*bpePair))
}
func (h *bpePairHeap) Pop() any {
old := *h
n := len(old)
item := old[n-1]
*h = old[:n-1]
return item
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
runes := []rune(encoded)
if len(runes) == 0 {
return ids
}
nodes := make([]bpeMergeNode, len(runes))
for i := range runes {
nodes[i] = bpeMergeNode{
prev: i - 1,
next: i + 1,
token: string(runes[i]),
}
}
pairwise := func(left, right int) *bpePair {
if left < 0 || right >= len(nodes) {
return nil
}
if nodes[left].token == "" || nodes[right].token == "" {
return nil
}
leftToken, rightToken := nodes[left].token, nodes[right].token
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
if !ok {
return nil
}
value := leftToken + rightToken
if _, ok := t.vocab.Reverse[value]; !ok {
return nil
}
return &bpePair{
left: left,
right: right,
rank: rank,
value: value,
}
}
pairs := bpePairHeap{}
heap.Init(&pairs)
for i := 0; i < len(runes)-1; i++ {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(&pairs, pair)
}
}
for pairs.Len() > 0 {
pair := heap.Pop(&pairs).(*bpePair)
left, right := nodes[pair.left], nodes[pair.right]
if left.token == "" || right.token == "" {
continue
}
if left.next != pair.right || right.prev != pair.left {
continue
}
if left.token+right.token != pair.value {
continue
}
nodes[pair.left].token = pair.value
nodes[pair.right].token = ""
nodes[pair.left].next = right.next
if right.next < len(nodes) {
nodes[right.next].prev = pair.left
}
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
heap.Push(&pairs, pair)
}
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
heap.Push(&pairs, pair)
}
}
for _, node := range nodes {
if node.token == "" {
continue
}
if id, ok := t.vocab.Reverse[node.token]; ok {
ids = append(ids, id)
continue
}
ids = t.appendByteFallback(ids, node.token)
}
return ids
}
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
if t.typ == TokenizerBPE {
for _, r := range token {
if b, ok := decodeByteLevelRune(r); ok {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
return ids
}
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
for _, b := range []byte(token) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
return ids
}
func decodeByteLevelRune(r rune) (byte, bool) {
switch {
case r >= 0x00 && r <= 0xFF:
return byte(r), true
case r == 0x0100:
return 0x00, true
case r == 0x0143:
return 0x00ad, true
case r > 0x0100 && r <= 0x0120:
return byte(r - 0x0100), true
case r > 0x0120 && r <= 0x0142:
return byte(r - 0x00a2), true
default:
return 0, false
}
}

View File

@@ -1,137 +0,0 @@
//go:build mlx
package tokenizer
import (
"runtime"
"strings"
"testing"
)
func equalIDs(a, b []int32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestEncodeRoundtripMiniLlama(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
inputs := []string{
"",
"hello",
"hello world",
" hello world ",
"don't we'll they're",
"1234567890",
"こんにちは世界",
"Hello 世界",
"func main() {}",
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
}
for _, input := range inputs {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
}
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
data := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0, "b": 1},
"merges": []
},
"added_tokens": [
{"id": 2, "content": "<tag>", "special": true},
{"id": 3, "content": "<tag>x", "special": true}
]
}`)
tok, err := LoadFromBytes(data)
if err != nil {
t.Fatalf("failed to load tokenizer: %v", err)
}
input := "a<tag>xb"
want := []string{"a", "<tag>x", "b"}
// Simulate construction outside loader path where cache is not set.
tok.sortedSpecialTokens = nil
got := tok.splitBySpecialTokens(input)
if len(got) != len(want) {
t.Fatalf("split length mismatch: got %v want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
}
}
}
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
tok := benchmarkLoadMiniLlama(t)
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
prev := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(prev)
runtime.GOMAXPROCS(1)
seq := tok.Encode(input, false)
if prev < 2 {
runtime.GOMAXPROCS(2)
} else {
runtime.GOMAXPROCS(prev)
}
par := tok.Encode(input, false)
if !equalIDs(seq, par) {
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
}
}

View File

@@ -1,56 +0,0 @@
//go:build mlx
package tokenizer
import (
"strconv"
"strings"
)
// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// Mirror GGML tokenizer behavior for NULL byte.
// 0x00 is omitted during decode.
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}

View File

@@ -1,289 +0,0 @@
//go:build mlx
package tokenizer
import (
"runtime"
"sort"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
const (
encodeParallelMinInputBytes = 4 * 1024
encodeParallelMinChunksPerWorker = 8
)
type tokenMatch struct {
start int
end int
}
type encodeChunk struct {
text string
isSpecial bool
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
tokens := t.sortedSpecialTokens
if len(tokens) == 0 {
// Fallback for tokenizers constructed outside the loaders.
tokens = make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
}
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
m := part[curr.start:curr.end]
nextText := part[next.start:next.end]
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
return
}
firstRune, _ := utf8.DecodeRuneInString(nextText)
if !unicode.IsLetter(firstRune) {
return
}
lastSpaceStart := curr.end
for j := curr.end; j > curr.start; {
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > curr.start {
curr.end = lastSpaceStart
next.start = lastSpaceStart
} else {
next.start = curr.start
curr.end = curr.start
}
}
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
if _, ok := t.specialTokens[part]; ok {
fn(encodeChunk{text: part, isSpecial: true})
return
}
if t.pretokenizer == nil {
fn(encodeChunk{text: part, isSpecial: false})
return
}
re := t.pretokenizer
offset := 0
loc := re.FindStringIndex(part[offset:])
if loc == nil {
return
}
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
for {
loc = re.FindStringIndex(part[offset:])
if loc == nil {
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
return
}
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
offset += loc[1]
adjustWhitespaceBoundary(part, &curr, &next)
if curr.end > curr.start {
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
}
curr = next
}
}
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
return append(ids, id)
}
return ids
}
return t.encodeChunkInto(c.text, ids)
}
// Encode tokenizes text to token IDs.
// Parallel encoding is used only for very large inputs with enough chunks per worker.
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Fast path: encode sequentially without materializing chunk slices.
if len(s) < encodeParallelMinInputBytes {
var ids []int32
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
ids = t.appendEncodedChunk(ids, c)
})
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// For large inputs collect chunks to enable parallel processing.
var allChunks []encodeChunk
for _, part := range parts {
t.forEachPartChunk(part, func(c encodeChunk) {
allChunks = append(allChunks, c)
})
}
// Encode chunks. Use the parallel path only when the chunk count is
// large enough to amortize goroutine/synchronization overhead.
useParallel := true
numWorkers := runtime.GOMAXPROCS(0)
if numWorkers > len(allChunks) {
numWorkers = len(allChunks)
}
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
useParallel = false
}
var ids []int32
if !useParallel {
for _, c := range allChunks {
ids = t.appendEncodedChunk(ids, c)
}
} else {
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []encodeChunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
r = t.appendEncodedChunk(r, c)
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}

View File

@@ -1,207 +0,0 @@
//go:build mlx
package tokenizer
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func llama32GGMLFixturePath(tb testing.TB, file string) string {
tb.Helper()
_, filename, _, ok := runtime.Caller(0)
if !ok {
tb.Fatal("failed to resolve test file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
tb.Helper()
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
if err != nil {
tb.Fatalf("failed to open encoder.json: %v", err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
tb.Fatalf("failed to decode encoder.json: %v", err)
}
type addedToken struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
}
var addedTokens []addedToken
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
id := int32(len(vocab))
vocab[token] = id
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
}
}
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
if err != nil {
tb.Fatalf("failed to open vocab.bpe: %v", err)
}
defer mf.Close()
var merges []string
scanner := bufio.NewScanner(mf)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
line = strings.TrimSpace(line)
if line != "" {
merges = append(merges, line)
}
}
if err := scanner.Err(); err != nil {
tb.Fatalf("failed to read vocab.bpe: %v", err)
}
payload := struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
AddedTokens []addedToken `json:"added_tokens"`
}{}
payload.Model.Type = "BPE"
payload.Model.Vocab = vocab
payload.Model.Merges = merges
payload.PreTokenizer.Type = "Sequence"
payload.PreTokenizer.Pretokenizers = []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}{
{
Type: "Split",
Pattern: struct {
Regex string `json:"Regex"`
}{
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
},
},
}
payload.AddedTokens = addedTokens
data, err := json.Marshal(payload)
if err != nil {
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
}
tok, err := LoadFromBytes(data)
if err != nil {
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
}
return tok
}
func TestGGMLLlamaKnownEncodings(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[string][]int32{
"hello world": {15339, 1917},
"hello <|end_of_text|>": {15339, 220, 128001},
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for input, want := range cases {
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := map[int][]int32{
1: {15},
2: {410},
3: {931},
4: {931, 15},
5: {931, 410},
6: {931, 931},
7: {931, 931, 15},
8: {931, 931, 410},
9: {931, 931, 931},
10: {931, 931, 931, 15},
11: {931, 931, 931, 410},
12: {931, 931, 931, 931},
13: {931, 931, 931, 931, 15},
14: {931, 931, 931, 931, 410},
15: {931, 931, 931, 931, 931},
16: {931, 931, 931, 931, 931, 15},
17: {931, 931, 931, 931, 931, 410},
}
for n, want := range cases {
input := strings.Repeat("0", n)
got := tok.Encode(input, false)
if !equalIDs(got, want) {
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
}
}
}
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
tok := loadLlama32FromGGMLFixture(t)
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, input := range cases {
ids := tok.Encode(input, false)
got := tok.Decode(ids)
if got != input {
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
}
}
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
ids := tok.Encode(string(rune(0x00)), false)
got := tok.Decode(ids)
if got != "" {
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
}
}

View File

@@ -1,458 +0,0 @@
//go:build mlx
package tokenizer
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data)
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data)
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE"
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Covers SentencePiece and BPE models
if raw.Model.Type != "BPE" {
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
var mergesStrings []string
if raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
if len(pair) != 2 {
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
}
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
t := &Tokenizer{
vocab: &Vocabulary{
Values: make([]string, len(raw.Model.Vocab)),
Reverse: raw.Model.Vocab,
Merges: make(map[string]int, len(mergesStrings)),
BOS: -1,
PAD: -1,
},
specialTokens: make(map[string]int32),
}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add all added_tokens to vocabulary and special tokens map.
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
// they bypass BPE and get their own token ID. The "special" flag just indicates
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
// to treat all added_tokens as special to match HuggingFace behavior.
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
}
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
// Determine tokenizer type
switch {
case detectSentencePiece(raw.Decoder):
t.typ = TokenizerSentencePiece
default:
t.typ = TokenizerBPE
}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
cacheSortedSpecialTokens(t)
return t, nil
}
func cacheSortedSpecialTokens(t *Tokenizer) {
if len(t.specialTokens) == 0 {
t.sortedSpecialTokens = nil
return
}
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
t.sortedSpecialTokens = tokens
}
type specialTokenConfigData struct {
tokenizerConfigJSON []byte
generationConfigJSON []byte
specialTokensMapJSON []byte
configJSON []byte
}
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.generationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.tokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.specialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
applySpecialTokenConfig(t, specialTokenConfigData{
tokenizerConfigJSON: config.TokenizerConfigJSON,
generationConfigJSON: config.GenerationConfigJSON,
specialTokensMapJSON: config.SpecialTokensMapJSON,
configJSON: config.ConfigJSON,
})
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}

View File

@@ -1,26 +0,0 @@
//go:build mlx
package tokenizer
import (
"strings"
"testing"
)
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
data := []byte(`{
"model": {
"type": "WordPiece",
"vocab": {"[UNK]": 0, "hello": 1}
},
"added_tokens": []
}`)
_, err := LoadFromBytes(data)
if err == nil {
t.Fatal("expected WordPiece load to fail")
}
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
t.Fatalf("unexpected error: %v", err)
}
}