mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
3 Commits
v0.17.0-rc
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ac1300805 | ||
|
|
43d9907dd6 | ||
|
|
91dc088e8b |
13
api/types.go
13
api/types.go
@@ -922,6 +922,19 @@ type UserResponse struct {
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
type UsageResponse struct {
|
||||
// Start is the time the server started tracking usage (UTC, RFC 3339).
|
||||
Start time.Time `json:"start"`
|
||||
Usage []ModelUsageData `json:"usage"`
|
||||
}
|
||||
|
||||
type ModelUsageData struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -41,11 +41,6 @@ type InferenceCompute struct {
|
||||
VRAM string
|
||||
}
|
||||
|
||||
type InferenceInfo struct {
|
||||
Computes []InferenceCompute
|
||||
DefaultContextLength int
|
||||
}
|
||||
|
||||
func New(s *store.Store, devMode bool) *Server {
|
||||
p := resolvePath("ollama")
|
||||
return &Server{store: s, bin: p, dev: devMode}
|
||||
@@ -277,12 +272,9 @@ func openRotatingLog() (io.WriteCloser, error) {
|
||||
|
||||
// Attempt to retrieve inference compute information from the server
|
||||
// log. Set ctx to timeout to control how long to wait for the logs to appear
|
||||
func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||
info := &InferenceInfo{}
|
||||
computeMarker := regexp.MustCompile(`inference compute.*library=`)
|
||||
defaultCtxMarker := regexp.MustCompile(`vram-based default context`)
|
||||
defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`)
|
||||
|
||||
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
inference := []InferenceCompute{}
|
||||
marker := regexp.MustCompile(`inference compute.*library=`)
|
||||
q := `inference compute.*%s=["]([^"]*)["]`
|
||||
nq := `inference compute.*%s=(\S+)\s`
|
||||
type regex struct {
|
||||
@@ -348,8 +340,8 @@ func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
// Check for inference compute lines
|
||||
if computeMarker.MatchString(line) {
|
||||
match := marker.FindStringSubmatch(line)
|
||||
if len(match) > 0 {
|
||||
ic := InferenceCompute{
|
||||
Library: get("library", line),
|
||||
Variant: get("variant", line),
|
||||
@@ -360,25 +352,12 @@ func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) {
|
||||
}
|
||||
|
||||
slog.Info("Matched", "inference compute", ic)
|
||||
info.Computes = append(info.Computes, ic)
|
||||
continue
|
||||
}
|
||||
// Check for default context length line
|
||||
if defaultCtxMarker.MatchString(line) {
|
||||
match := defaultCtxRegex.FindStringSubmatch(line)
|
||||
if len(match) > 1 {
|
||||
numCtx, err := strconv.Atoi(match[1])
|
||||
if err == nil {
|
||||
info.DefaultContextLength = numCtx
|
||||
slog.Info("Matched default context length", "default_num_ctx", numCtx)
|
||||
}
|
||||
inference = append(inference, ic)
|
||||
} else {
|
||||
// Break out on first non matching line after we start matching
|
||||
if len(inference) > 0 {
|
||||
return inference, nil
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
// If we've found compute info but hit a non-matching line, return what we have
|
||||
// This handles older server versions that don't log the default context line
|
||||
if len(info.Computes) > 0 {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -205,50 +205,44 @@ func TestServerCmdCloudSettingEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInferenceInfo(t *testing.T) {
|
||||
func TestGetInferenceComputer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
log string
|
||||
expComputes []InferenceCompute
|
||||
expDefaultCtxLen int
|
||||
name string
|
||||
log string
|
||||
exp []InferenceCompute
|
||||
}{
|
||||
{
|
||||
name: "metal",
|
||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||
time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144
|
||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||
`,
|
||||
expComputes: []InferenceCompute{{
|
||||
exp: []InferenceCompute{{
|
||||
Library: "metal",
|
||||
Driver: "0.0",
|
||||
VRAM: "96.0 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 262144,
|
||||
},
|
||||
{
|
||||
name: "cpu",
|
||||
log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered"
|
||||
time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB"
|
||||
time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
expComputes: []InferenceCompute{{
|
||||
exp: []InferenceCompute{{
|
||||
Library: "cpu",
|
||||
Driver: "0.0",
|
||||
VRAM: "31.3 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 32768,
|
||||
},
|
||||
{
|
||||
name: "cuda1",
|
||||
log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu"
|
||||
releasing cuda driver library
|
||||
time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB"
|
||||
time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
expComputes: []InferenceCompute{{
|
||||
exp: []InferenceCompute{{
|
||||
Library: "cuda",
|
||||
Variant: "v12",
|
||||
Compute: "6.1",
|
||||
@@ -256,7 +250,6 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
|
||||
Name: "NVIDIA GeForce GT 1030",
|
||||
VRAM: "3.9 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 4096,
|
||||
},
|
||||
{
|
||||
name: "frank",
|
||||
@@ -264,10 +257,9 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
|
||||
releasing cuda driver library
|
||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB"
|
||||
time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB"
|
||||
time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768
|
||||
[GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/"
|
||||
`,
|
||||
expComputes: []InferenceCompute{
|
||||
exp: []InferenceCompute{
|
||||
{
|
||||
Library: "cuda",
|
||||
Variant: "v12",
|
||||
@@ -284,20 +276,6 @@ time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based d
|
||||
VRAM: "16.0 GiB",
|
||||
},
|
||||
},
|
||||
expDefaultCtxLen: 32768,
|
||||
},
|
||||
{
|
||||
name: "missing_default_context",
|
||||
log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler"
|
||||
time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB"
|
||||
time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32
|
||||
`,
|
||||
expComputes: []InferenceCompute{{
|
||||
Library: "metal",
|
||||
Driver: "0.0",
|
||||
VRAM: "96.0 GiB",
|
||||
}},
|
||||
expDefaultCtxLen: 0, // No default context line, should return 0
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -310,21 +288,18 @@ time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not f
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
info, err := GetInferenceInfo(ctx)
|
||||
ics, err := GetInferenceComputer(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get inference info: %v", err)
|
||||
t.Fatalf(" failed to get inference compute: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(info.Computes, tt.expComputes) {
|
||||
t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes)
|
||||
}
|
||||
if info.DefaultContextLength != tt.expDefaultCtxLen {
|
||||
t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen)
|
||||
if !reflect.DeepEqual(ics, tt.exp) {
|
||||
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInferenceInfoTimeout(t *testing.T) {
|
||||
func TestGetInferenceComputerTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
tmpDir := t.TempDir()
|
||||
@@ -333,7 +308,7 @@ func TestGetInferenceInfoTimeout(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
|
||||
}
|
||||
_, err = GetInferenceInfo(ctx)
|
||||
_, err = GetInferenceComputer(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout")
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 14
|
||||
const currentSchemaVersion = 13
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -73,7 +73,7 @@ func (db *database) init() error {
|
||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||
working_dir TEXT NOT NULL DEFAULT '',
|
||||
context_length INTEGER NOT NULL DEFAULT 0,
|
||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
||||
window_width INTEGER NOT NULL DEFAULT 0,
|
||||
window_height INTEGER NOT NULL DEFAULT 0,
|
||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
@@ -251,12 +251,6 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||
}
|
||||
version = 13
|
||||
case 13:
|
||||
// change default context_length from 4096 to 0 (VRAM-based tiered defaults)
|
||||
if err := db.migrateV13ToV14(); err != nil {
|
||||
return fmt.Errorf("migrate v13 to v14: %w", err)
|
||||
}
|
||||
version = 14
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -480,22 +474,6 @@ func (db *database) migrateV12ToV13() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV13ToV14 changes the default context_length from 4096 to 0.
|
||||
// When context_length is 0, the ollama server uses VRAM-based tiered defaults.
|
||||
func (db *database) migrateV13ToV14() error {
|
||||
_, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update context_length default: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
|
||||
@@ -98,43 +98,6 @@ func TestSchemaMigrations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationV13ToV14ContextLength(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
db, err := newDatabase(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to seed v13 settings row: %v", err)
|
||||
}
|
||||
|
||||
if err := db.migrate(); err != nil {
|
||||
t.Fatalf("migration from v13 to v14 failed: %v", err)
|
||||
}
|
||||
|
||||
var contextLength int
|
||||
if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil {
|
||||
t.Fatalf("failed to read context_length: %v", err)
|
||||
}
|
||||
|
||||
if contextLength != 0 {
|
||||
t.Fatalf("expected context_length to migrate to 0, got %d", contextLength)
|
||||
}
|
||||
|
||||
version, err := db.getSchemaVersion()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get schema version: %v", err)
|
||||
}
|
||||
if version != currentSchemaVersion {
|
||||
t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatDeletionWithCascade(t *testing.T) {
|
||||
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
2
app/store/testdata/schema.sql
vendored
2
app/store/testdata/schema.sql
vendored
@@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings (
|
||||
agent BOOLEAN NOT NULL DEFAULT 0,
|
||||
tools BOOLEAN NOT NULL DEFAULT 0,
|
||||
working_dir TEXT NOT NULL DEFAULT '',
|
||||
context_length INTEGER NOT NULL DEFAULT 0,
|
||||
context_length INTEGER NOT NULL DEFAULT 4096,
|
||||
window_width INTEGER NOT NULL DEFAULT 0,
|
||||
window_height INTEGER NOT NULL DEFAULT 0,
|
||||
config_migrated BOOLEAN NOT NULL DEFAULT 0,
|
||||
|
||||
@@ -289,12 +289,10 @@ export class InferenceCompute {
|
||||
}
|
||||
export class InferenceComputeResponse {
|
||||
inferenceComputes: InferenceCompute[];
|
||||
defaultContextLength: number;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute);
|
||||
this.defaultContextLength = source["defaultContextLength"];
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
ChatEvent,
|
||||
DownloadEvent,
|
||||
ErrorEvent,
|
||||
InferenceCompute,
|
||||
InferenceComputeResponse,
|
||||
ModelCapabilitiesResponse,
|
||||
Model,
|
||||
@@ -406,7 +407,7 @@ export async function* pullModel(
|
||||
}
|
||||
}
|
||||
|
||||
export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
|
||||
export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/inference-compute`);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
@@ -415,7 +416,8 @@ export async function getInferenceCompute(): Promise<InferenceComputeResponse> {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return new InferenceComputeResponse(data);
|
||||
const inferenceComputeResponse = new InferenceComputeResponse(data);
|
||||
return inferenceComputeResponse.inferenceComputes || [];
|
||||
}
|
||||
|
||||
export async function fetchHealth(): Promise<boolean> {
|
||||
|
||||
@@ -26,7 +26,6 @@ import {
|
||||
type CloudStatusResponse,
|
||||
updateCloudSetting,
|
||||
updateSettings,
|
||||
getInferenceCompute,
|
||||
} from "@/api";
|
||||
|
||||
function AnimatedDots() {
|
||||
@@ -78,13 +77,6 @@ export default function Settings() {
|
||||
|
||||
const settings = settingsData?.settings || null;
|
||||
|
||||
const { data: inferenceComputeResponse } = useQuery({
|
||||
queryKey: ["inferenceCompute"],
|
||||
queryFn: getInferenceCompute,
|
||||
});
|
||||
|
||||
const defaultContextLength = inferenceComputeResponse?.defaultContextLength;
|
||||
|
||||
const updateSettingsMutation = useMutation({
|
||||
mutationFn: updateSettings,
|
||||
onSuccess: () => {
|
||||
@@ -212,7 +204,7 @@ export default function Settings() {
|
||||
Models: "",
|
||||
Agent: false,
|
||||
Tools: false,
|
||||
ContextLength: 0,
|
||||
ContextLength: 4096,
|
||||
});
|
||||
updateSettingsMutation.mutate(defaultSettings);
|
||||
}
|
||||
@@ -515,11 +507,13 @@ export default function Settings() {
|
||||
</Description>
|
||||
<div className="mt-3">
|
||||
<Slider
|
||||
value={settings.ContextLength || defaultContextLength || 0}
|
||||
value={(() => {
|
||||
// Otherwise use the settings value
|
||||
return settings.ContextLength || 4096;
|
||||
})()}
|
||||
onChange={(value) => {
|
||||
handleChange("ContextLength", value);
|
||||
}}
|
||||
disabled={!defaultContextLength}
|
||||
options={[
|
||||
{ value: 4096, label: "4k" },
|
||||
{ value: 8192, label: "8k" },
|
||||
|
||||
@@ -6,11 +6,10 @@ export interface SliderProps {
|
||||
value?: number;
|
||||
onChange?: (value: number) => void;
|
||||
className?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
({ label, options, value = 0, onChange, disabled = false }, ref) => {
|
||||
({ label, options, value = 0, onChange }, ref) => {
|
||||
const [selectedValue, setSelectedValue] = React.useState(value);
|
||||
const [isDragging, setIsDragging] = React.useState(false);
|
||||
const containerRef = React.useRef<HTMLDivElement>(null);
|
||||
@@ -21,7 +20,6 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
}, [value]);
|
||||
|
||||
const handleClick = (optionValue: number) => {
|
||||
if (disabled) return;
|
||||
setSelectedValue(optionValue);
|
||||
onChange?.(optionValue);
|
||||
};
|
||||
@@ -41,7 +39,6 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
};
|
||||
|
||||
const handleMouseDown = (e: React.MouseEvent) => {
|
||||
if (disabled) return;
|
||||
setIsDragging(true);
|
||||
e.preventDefault();
|
||||
};
|
||||
@@ -80,7 +77,7 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`space-y-2 ${disabled ? "opacity-50" : ""}`} ref={ref}>
|
||||
<div className="space-y-2" ref={ref}>
|
||||
{label && <label className="text-sm font-medium">{label}</label>}
|
||||
<div className="relative">
|
||||
<div className="absolute top-[9px] left-2 right-2 h-1 bg-neutral-200 dark:bg-neutral-700 pointer-events-none rounded-full" />
|
||||
@@ -91,11 +88,10 @@ const Slider = React.forwardRef<HTMLDivElement, SliderProps>(
|
||||
<button
|
||||
onClick={() => handleClick(option.value)}
|
||||
onMouseDown={handleMouseDown}
|
||||
disabled={disabled}
|
||||
className={`relative px-3 py-6 -mx-3 -my-6 z-10 ${disabled ? "cursor-not-allowed" : "cursor-pointer"}`}
|
||||
className="relative px-3 py-6 -mx-3 -my-6 z-10 cursor-pointer"
|
||||
>
|
||||
<div className="relative w-5 h-5 flex items-center justify-center">
|
||||
{selectedValue === option.value && !disabled && (
|
||||
{selectedValue === option.value && (
|
||||
<div className="w-4 h-4 bg-white dark:bg-white border border-neutral-400 dark:border-neutral-500 rounded-full cursor-grab active:cursor-grabbing" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -28,14 +28,12 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
currentChatId && currentChatId !== "new" ? currentChatId : "",
|
||||
);
|
||||
|
||||
const { data: inferenceComputeResponse } = useQuery({
|
||||
queryKey: ["inferenceCompute"],
|
||||
const { data: inferenceComputes = [] } = useQuery({
|
||||
queryKey: ["inference-compute"],
|
||||
queryFn: getInferenceCompute,
|
||||
enabled: !settings.selectedModel, // Only fetch if no model is selected
|
||||
});
|
||||
|
||||
const inferenceComputes = inferenceComputeResponse?.inferenceComputes || [];
|
||||
|
||||
const totalVRAM = useMemo(
|
||||
() => getTotalVRAM(inferenceComputes),
|
||||
[inferenceComputes],
|
||||
|
||||
@@ -45,8 +45,7 @@ type InferenceCompute struct {
|
||||
}
|
||||
|
||||
type InferenceComputeResponse struct {
|
||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||
DefaultContextLength int `json:"defaultContextLength"`
|
||||
InferenceComputes []InferenceCompute `json:"inferenceComputes"`
|
||||
}
|
||||
|
||||
type ModelCapabilitiesResponse struct {
|
||||
|
||||
18
app/ui/ui.go
18
app/ui/ui.go
@@ -1420,6 +1420,11 @@ func (s *Server) getSettings(w http.ResponseWriter, r *http.Request) error {
|
||||
settings.Models = envconfig.Models()
|
||||
}
|
||||
|
||||
// set default context length if not set
|
||||
if settings.ContextLength == 0 {
|
||||
settings.ContextLength = 4096
|
||||
}
|
||||
|
||||
// Include current runtime settings
|
||||
settings.Agent = s.Agent
|
||||
settings.Tools = s.Tools
|
||||
@@ -1495,14 +1500,14 @@ func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
|
||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
info, err := server.GetInferenceInfo(ctx)
|
||||
serverInferenceComputes, err := server.GetInferenceComputer(ctx)
|
||||
if err != nil {
|
||||
s.log().Error("failed to get inference info", "error", err)
|
||||
return fmt.Errorf("failed to get inference info: %w", err)
|
||||
s.log().Error("failed to get inference compute", "error", err)
|
||||
return fmt.Errorf("failed to get inference compute: %w", err)
|
||||
}
|
||||
|
||||
inferenceComputes := make([]responses.InferenceCompute, len(info.Computes))
|
||||
for i, ic := range info.Computes {
|
||||
inferenceComputes := make([]responses.InferenceCompute, len(serverInferenceComputes))
|
||||
for i, ic := range serverInferenceComputes {
|
||||
inferenceComputes[i] = responses.InferenceCompute{
|
||||
Library: ic.Library,
|
||||
Variant: ic.Variant,
|
||||
@@ -1514,8 +1519,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
|
||||
}
|
||||
|
||||
response := responses.InferenceComputeResponse{
|
||||
InferenceComputes: inferenceComputes,
|
||||
DefaultContextLength: info.DefaultContextLength,
|
||||
InferenceComputes: inferenceComputes,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1956,10 +1956,6 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
}
|
||||
|
||||
launchIntegration := func(name string) bool {
|
||||
if err := config.EnsureInstalled(name); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
return true
|
||||
}
|
||||
// If not configured or model no longer exists, prompt for model selection
|
||||
configuredModel := config.IntegrationModel(name)
|
||||
if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -33,10 +32,6 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -15,9 +15,8 @@ import (
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
Onboarded bool `json:"onboarded,omitempty"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -140,54 +139,34 @@ func SaveIntegration(appName string, models []string) error {
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
var onboarded bool
|
||||
if existing != nil {
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
onboarded = existing.Onboarded
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
Onboarded: onboarded,
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// integrationOnboarded marks an integration as onboarded in ollama's config.
|
||||
func integrationOnboarded(appName string) error {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
existing.Onboarded = true
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
|
||||
func IntegrationModel(appName string) string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return integrationConfig.Models[0]
|
||||
return ic.Models[0]
|
||||
}
|
||||
|
||||
// IntegrationModels returns all configured models for an integration, or nil.
|
||||
func IntegrationModels(appName string) []string {
|
||||
integrationConfig, err := loadIntegration(appName)
|
||||
if err != nil || len(integrationConfig.Models) == 0 {
|
||||
ic, err := loadIntegration(appName)
|
||||
if err != nil || len(ic.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
return integrationConfig.Models
|
||||
return ic.Models
|
||||
}
|
||||
|
||||
// LastModel returns the last model that was run, or empty string if none.
|
||||
@@ -255,12 +234,12 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
integrationConfig, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
ic, ok := cfg.Integrations[strings.ToLower(appName)]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return integrationConfig, nil
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
@@ -293,8 +272,8 @@ func listIntegrations() ([]integration, error) {
|
||||
}
|
||||
|
||||
result := make([]integration, 0, len(cfg.Integrations))
|
||||
for _, integrationConfig := range cfg.Integrations {
|
||||
result = append(result, *integrationConfig)
|
||||
for _, ic := range cfg.Integrations {
|
||||
result = append(result, *ic)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
@@ -228,31 +228,6 @@ func IsIntegrationInstalled(name string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// AutoInstallable returns true if the integration can be automatically
|
||||
// installed when not found (e.g. via npm).
|
||||
func AutoInstallable(name string) bool {
|
||||
switch strings.ToLower(name) {
|
||||
case "openclaw", "clawdbot", "moltbot":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureInstalled checks if an auto-installable integration is present and
|
||||
// offers to install it if missing. Returns nil for non-auto-installable
|
||||
// integrations or when the binary is already on PATH.
|
||||
func EnsureInstalled(name string) error {
|
||||
if !AutoInstallable(name) {
|
||||
return nil
|
||||
}
|
||||
if IsIntegrationInstalled(name) {
|
||||
return nil
|
||||
}
|
||||
_, err := ensureOpenclawInstalled()
|
||||
return err
|
||||
}
|
||||
|
||||
// IsEditorIntegration returns true if the named integration uses multi-model
|
||||
// selection (implements the Editor interface).
|
||||
func IsEditorIntegration(name string) bool {
|
||||
@@ -951,10 +926,6 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
if err := EnsureInstalled(name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
|
||||
modelFlag = ""
|
||||
}
|
||||
|
||||
@@ -1,287 +1,81 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const defaultGatewayPort = 18789
|
||||
|
||||
// Bound model capability probing so launch/config cannot hang on slow/unreachable API calls.
|
||||
var openclawModelShowTimeout = 5 * time.Second
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin, err := ensureOpenclawInstalled()
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
||||
}
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
var err error
|
||||
models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
|
||||
return selectModels(context.Background(), "openclaw", "")
|
||||
})
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstLaunch := true
|
||||
if integrationConfig, err := loadIntegration("openclaw"); err == nil {
|
||||
firstLaunch = !integrationConfig.Onboarded
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSecurity%s\n\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " OpenClaw can read files and run actions when tools are enabled.\n")
|
||||
fmt.Fprintf(os.Stderr, " A bad prompt can trick it into doing unsafe things.\n\n")
|
||||
fmt.Fprintf(os.Stderr, "%s Learn more: https://docs.openclaw.ai/gateway/security%s\n\n", ansiGray, ansiReset)
|
||||
|
||||
ok, err := confirmPrompt("I understand the risks. Continue?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
fmt.Fprintf(os.Stderr, "\n%sSetting up OpenClaw with Ollama...%s\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Model: %s%s\n\n", ansiGray, model, ansiReset)
|
||||
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--non-interactive",
|
||||
"--accept-risk",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
"--install-daemon",
|
||||
"--skip-channels",
|
||||
"--skip-skills",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(fmt.Errorf("openclaw onboarding failed: %w\n\nTry running: openclaw onboard", err))
|
||||
}
|
||||
|
||||
patchDeviceScopes()
|
||||
|
||||
// Onboarding overwrites openclaw.json, so re-apply the model config
|
||||
// that Edit() wrote before Run() was called.
|
||||
if err := c.Edit([]string{model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not re-apply model config: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
if strings.HasSuffix(model, ":cloud") || strings.HasSuffix(model, "-cloud") {
|
||||
if ensureWebSearchPlugin() {
|
||||
registerWebSearchPlugin()
|
||||
}
|
||||
}
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "\n%sPreparing your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\n%sStarting your assistant — this may take a moment...%s\n\n", ansiGray, ansiReset)
|
||||
}
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
// When extra args are passed through, run exactly what the user asked for
|
||||
// after setup and skip the built-in gateway+TUI convenience flow.
|
||||
if len(args) > 0 {
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Env = openclawEnv()
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
err = cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
// If the gateway is already running (e.g. via the daemon), restart it
|
||||
// so it picks up any config changes from Edit() above (model, provider, etc.).
|
||||
if portOpen(addr) {
|
||||
restart := exec.Command(bin, "daemon", "restart")
|
||||
restart.Env = openclawEnv()
|
||||
if err := restart.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: daemon restart failed: %v%s\n", ansiYellow, err, ansiReset)
|
||||
}
|
||||
if !waitForPort(addr, 10*time.Second) {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: gateway did not come back after restart%s\n", ansiYellow, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// If the gateway isn't running, start it as a background child process.
|
||||
if !portOpen(addr) {
|
||||
gw := exec.Command(bin, "gateway", "run", "--force")
|
||||
gw.Env = openclawEnv()
|
||||
if err := gw.Start(); err != nil {
|
||||
return windowsHint(fmt.Errorf("failed to start gateway: %w", err))
|
||||
}
|
||||
defer func() {
|
||||
if gw.Process != nil {
|
||||
_ = gw.Process.Kill()
|
||||
_ = gw.Wait()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sStarting gateway...%s\n", ansiGray, ansiReset)
|
||||
if !waitForPort(addr, 30*time.Second) {
|
||||
return windowsHint(fmt.Errorf("gateway did not start on %s", addr))
|
||||
}
|
||||
|
||||
printOpenclawReady(bin, token, port, firstLaunch)
|
||||
|
||||
tuiArgs := []string{"tui"}
|
||||
if firstLaunch {
|
||||
tuiArgs = append(tuiArgs, "--message", "Wake up, my friend!")
|
||||
}
|
||||
tui := exec.Command(bin, tuiArgs...)
|
||||
tui.Env = openclawEnv()
|
||||
tui.Stdin = os.Stdin
|
||||
tui.Stdout = os.Stdout
|
||||
tui.Stderr = os.Stderr
|
||||
if err := tui.Run(); err != nil {
|
||||
return windowsHint(err)
|
||||
}
|
||||
|
||||
if firstLaunch {
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
return fmt.Errorf("failed to save onboarding state: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatewayInfo reads the gateway auth token and port from the OpenClaw config.
|
||||
func (c *Openclaw) gatewayInfo() (token string, port int) {
|
||||
port = defaultGatewayPort
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", port
|
||||
}
|
||||
|
||||
for _, path := range []string{
|
||||
filepath.Join(home, ".openclaw", "openclaw.json"),
|
||||
filepath.Join(home, ".clawdbot", "clawdbot.json"),
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
continue
|
||||
}
|
||||
gw, _ := config["gateway"].(map[string]any)
|
||||
if p, ok := gw["port"].(float64); ok && p > 0 {
|
||||
port = int(p)
|
||||
}
|
||||
auth, _ := gw["auth"].(map[string]any)
|
||||
if t, _ := auth["token"].(string); t != "" {
|
||||
token = t
|
||||
}
|
||||
return token, port
|
||||
}
|
||||
return "", port
|
||||
}
|
||||
|
||||
func printOpenclawReady(bin, token string, port int, firstLaunch bool) {
|
||||
u := fmt.Sprintf("http://localhost:%d", port)
|
||||
if token != "" {
|
||||
u += "/#token=" + url.QueryEscape(token)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%s✓ OpenClaw is running%s\n\n", ansiGreen, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, " Open the Web UI:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", hyperlink(u, u))
|
||||
|
||||
if firstLaunch {
|
||||
fmt.Fprintf(os.Stderr, "%s Quick start:%s\n", ansiBold, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s /help see all commands%s\n", ansiGray, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s configure --section channels connect WhatsApp, Telegram, etc.%s\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s %s skills browse and install skills%s\n\n", ansiGray, bin, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s The OpenClaw gateway is running in the background.%s\n", ansiYellow, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%s Stop it with: %s gateway stop%s\n\n", ansiYellow, bin, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%sTip: connect WhatsApp, Telegram, and more with: %s configure --section channels%s\n", ansiGray, bin, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
// openclawEnv returns the current environment with provider API keys cleared
|
||||
// so openclaw only uses the Ollama gateway, not keys from the user's shell.
|
||||
func openclawEnv() []string {
|
||||
clear := map[string]bool{
|
||||
"ANTHROPIC_API_KEY": true,
|
||||
"ANTHROPIC_OAUTH_TOKEN": true,
|
||||
"OPENAI_API_KEY": true,
|
||||
"GEMINI_API_KEY": true,
|
||||
"MISTRAL_API_KEY": true,
|
||||
"GROQ_API_KEY": true,
|
||||
"XAI_API_KEY": true,
|
||||
"OPENROUTER_API_KEY": true,
|
||||
}
|
||||
var env []string
|
||||
for _, e := range os.Environ() {
|
||||
key, _, _ := strings.Cut(e, "=")
|
||||
if !clear[key] {
|
||||
env = append(env, e)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// portOpen checks if a TCP port is currently accepting connections.
|
||||
func portOpen(addr string) bool {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func waitForPort(addr string, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
return true
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func windowsHint(err error) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w\n\n"+
|
||||
"OpenClaw runs best on WSL2.\n"+
|
||||
"Quick setup: wsl --install\n"+
|
||||
"Guide: https://docs.openclaw.ai/windows", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
@@ -313,144 +107,6 @@ func (c *Openclaw) onboarded() bool {
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
// patchDeviceScopes upgrades the local CLI device's paired scopes to include
|
||||
// operator.admin. Only patches the local device, not remote ones.
|
||||
// Best-effort: silently returns on any error.
|
||||
func patchDeviceScopes() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := readLocalDeviceID(home)
|
||||
if deviceID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
path := filepath.Join(home, ".openclaw", "devices", "paired.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var devices map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &devices); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
dev, ok := devices[deviceID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
required := []string{
|
||||
"operator.read",
|
||||
"operator.admin",
|
||||
"operator.approvals",
|
||||
"operator.pairing",
|
||||
}
|
||||
|
||||
changed := patchScopes(dev, "scopes", required)
|
||||
if tokens, ok := dev["tokens"].(map[string]any); ok {
|
||||
for _, tok := range tokens {
|
||||
if tokenMap, ok := tok.(map[string]any); ok {
|
||||
if patchScopes(tokenMap, "scopes", required) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(devices, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// readLocalDeviceID reads the local device ID from openclaw's identity file.
|
||||
func readLocalDeviceID(home string) string {
|
||||
data, err := os.ReadFile(filepath.Join(home, ".openclaw", "identity", "device-auth.json"))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var auth map[string]any
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return ""
|
||||
}
|
||||
id, _ := auth["deviceId"].(string)
|
||||
return id
|
||||
}
|
||||
|
||||
// patchScopes ensures obj[key] contains all required scopes. Returns true if
// any scopes were added.
func patchScopes(obj map[string]any, key string, required []string) bool {
	current, _ := obj[key].([]any)

	// Index the scopes already present (non-string entries are ignored).
	seen := make(map[string]bool, len(current))
	for _, v := range current {
		if name, ok := v.(string); ok {
			seen[name] = true
		}
	}

	missing := false
	for _, name := range required {
		if seen[name] {
			continue
		}
		current = append(current, name)
		missing = true
	}
	if missing {
		obj[key] = current
	}
	return missing
}
|
||||
|
||||
func ensureOpenclawInstalled() (string, error) {
|
||||
if _, err := exec.LookPath("openclaw"); err == nil {
|
||||
return "openclaw", nil
|
||||
}
|
||||
if _, err := exec.LookPath("clawdbot"); err == nil {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
||||
"Install Node.js first:\n" +
|
||||
" https://nodejs.org/\n\n" +
|
||||
"Then rerun:\n" +
|
||||
" ollama launch\n" +
|
||||
"and select OpenClaw")
|
||||
}
|
||||
|
||||
ok, err := confirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("openclaw installation cancelled")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling OpenClaw...\n")
|
||||
cmd := exec.Command("npm", "install", "-g", "openclaw@latest")
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install openclaw: %w", err)
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("openclaw"); err != nil {
|
||||
return "", fmt.Errorf("openclaw was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return "openclaw", nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
@@ -505,7 +161,8 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
ollama["api"] = "ollama"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
@@ -518,13 +175,25 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
|
||||
var newModels []any
|
||||
for _, m := range models {
|
||||
entry, _ := openclawModelConfig(context.Background(), client, m)
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[m]; ok {
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
@@ -561,237 +230,7 @@ func (c *Openclaw) Edit(models []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeWithBackup(configPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear any per-session model overrides so the new primary takes effect
|
||||
// immediately rather than being shadowed by a cached modelOverride.
|
||||
clearSessionModelOverride(models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearSessionModelOverride removes per-session model overrides from the main
|
||||
// agent session so the global primary model takes effect on the next TUI launch.
|
||||
func clearSessionModelOverride(primary string) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
path := filepath.Join(home, ".openclaw", "agents", "main", "sessions", "sessions.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if json.Unmarshal(data, &sessions) != nil {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, sess := range sessions {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
out, err := json.MarshalIndent(sessions, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, out, 0o600)
|
||||
}
|
||||
|
||||
// webSearchNpmPackage is the npm package that provides the OpenClaw
// web-search extension downloaded by ensureWebSearchPlugin.
const webSearchNpmPackage = "@ollama/openclaw-web-search"
|
||||
|
||||
// ensureWebSearchPlugin installs the openclaw-web-search extension into the OpenClaw
|
||||
// extensions directory if it isn't already present. Returns true if the extension
|
||||
// is available (either already installed or just installed).
|
||||
func ensureWebSearchPlugin() bool {
|
||||
extDir := openclawExtensionsDir()
|
||||
if extDir == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
pluginDir := filepath.Join(extDir, "openclaw-web-search")
|
||||
if _, err := os.Stat(filepath.Join(pluginDir, "index.ts")); err == nil {
|
||||
return true // already installed
|
||||
}
|
||||
|
||||
npmBin, err := exec.LookPath("npm")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
|
||||
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
|
||||
out, err := pack.Output()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
tgzName := strings.TrimSpace(string(out))
|
||||
tgzPath := filepath.Join(pluginDir, tgzName)
|
||||
defer os.Remove(tgzPath)
|
||||
|
||||
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
|
||||
if err := tar.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
|
||||
return false
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%s ✓ Installed web search plugin%s\n", ansiGreen, ansiReset)
|
||||
return true
|
||||
}
|
||||
|
||||
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
|
||||
// config so the gateway activates it on next start. Best-effort; silently returns
|
||||
// on any error.
|
||||
func registerWebSearchPlugin() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var config map[string]any
|
||||
if json.Unmarshal(data, &config) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
plugins, _ := config["plugins"].(map[string]any)
|
||||
if plugins == nil {
|
||||
plugins = make(map[string]any)
|
||||
}
|
||||
entries, _ := plugins["entries"].(map[string]any)
|
||||
if entries == nil {
|
||||
entries = make(map[string]any)
|
||||
}
|
||||
if _, ok := entries["openclaw-web-search"]; ok {
|
||||
return // already registered
|
||||
}
|
||||
entries["openclaw-web-search"] = map[string]any{"enabled": true}
|
||||
plugins["entries"] = entries
|
||||
config["plugins"] = plugins
|
||||
|
||||
// Disable the built-in web search since our plugin replaces it.
|
||||
tools, _ := config["tools"].(map[string]any)
|
||||
if tools == nil {
|
||||
tools = make(map[string]any)
|
||||
}
|
||||
web, _ := tools["web"].(map[string]any)
|
||||
if web == nil {
|
||||
web = make(map[string]any)
|
||||
}
|
||||
web["search"] = map[string]any{"enabled": false}
|
||||
tools["web"] = web
|
||||
config["tools"] = tools
|
||||
|
||||
out, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(configPath, out, 0o600)
|
||||
}
|
||||
|
||||
// openclawExtensionsDir resolves the extensions directory inside the openclaw
// npm package. Returns "" if the binary or path cannot be resolved.
func openclawExtensionsDir() string {
	var bin string
	for _, candidate := range []string{"openclaw", "clawdbot"} {
		if found, err := exec.LookPath(candidate); err == nil {
			bin = found
			break
		}
	}
	if bin == "" {
		return ""
	}

	resolved, err := filepath.EvalSymlinks(bin)
	if err != nil {
		return ""
	}

	// The binary symlink resolves to <pkg>/openclaw.mjs (package root);
	// extensions live alongside it at <pkg>/extensions/.
	extDir := filepath.Join(filepath.Dir(resolved), "extensions")
	if info, err := os.Stat(extDir); err == nil && info.IsDir() {
		return extDir
	}
	return ""
}
|
||||
|
||||
// openclawModelConfig builds an OpenClaw model config entry with capability detection.
|
||||
// The second return value indicates whether the model is a cloud (remote) model.
|
||||
func openclawModelConfig(ctx context.Context, client *api.Client, modelID string) (map[string]any, bool) {
|
||||
entry := map[string]any{
|
||||
"id": modelID,
|
||||
"name": modelID,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
showCtx := ctx
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
showCtx, cancel = context.WithTimeout(ctx, openclawModelShowTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
resp, err := client.Show(showCtx, &api.ShowRequest{Model: modelID})
|
||||
if err != nil {
|
||||
return entry, false
|
||||
}
|
||||
|
||||
// Set input types based on vision capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityVision) {
|
||||
entry["input"] = []any{"text", "image"}
|
||||
}
|
||||
|
||||
// Set reasoning based on thinking capability
|
||||
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
entry["reasoning"] = true
|
||||
}
|
||||
|
||||
// Cloud models: use hardcoded limits for context/output tokens.
|
||||
// Capability detection above still applies (vision, thinking).
|
||||
if resp.RemoteModel != "" {
|
||||
if l, ok := lookupCloudModelLimit(modelID); ok {
|
||||
entry["contextWindow"] = l.Context
|
||||
entry["maxTokens"] = l.Output
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
// Extract context window from ModelInfo (local models only)
|
||||
for key, val := range resp.ModelInfo {
|
||||
if strings.HasSuffix(key, ".context_length") {
|
||||
if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
|
||||
entry["contextWindow"] = int(ctxLen)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return entry, false
|
||||
return writeWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (c *Openclaw) Models() []string {
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOpenclawIntegration(t *testing.T) {
|
||||
@@ -36,124 +26,6 @@ func TestOpenclawIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawRunPassthroughArgs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses a POSIX shell test binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bin := filepath.Join(tmpDir, "openclaw")
|
||||
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--someflag"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "invocations.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected exactly 1 invocation, got %d: %v", len(lines), lines)
|
||||
}
|
||||
if lines[0] != "gateway --someflag" {
|
||||
t.Fatalf("invocation = %q, want %q", lines[0], "gateway --someflag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawRunFirstLaunchPersistence(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses a POSIX shell test binary")
|
||||
}
|
||||
|
||||
oldHook := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldHook }()
|
||||
|
||||
t.Run("success persists onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Mark OpenClaw onboarding complete so Run takes passthrough path directly.
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil {
|
||||
t.Fatalf("loadIntegration() error = %v", err)
|
||||
}
|
||||
if !integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to be persisted after successful run")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure does not persist onboarding flag", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"wizard": {"lastRunAt": "2026-01-01T00:00:00Z"}
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "openclaw"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := &Openclaw{}
|
||||
if err := c.Run("llama3.2", []string{"gateway", "--status"}); err == nil {
|
||||
t.Fatal("expected run failure")
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Fatal("expected onboarding flag to remain unset after failed run")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
@@ -487,16 +359,19 @@ func TestOpenclawEditSchemaFields(t *testing.T) {
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify base schema fields (always set regardless of API availability)
|
||||
if entry["id"] != "llama3.2" {
|
||||
t.Errorf("id = %v, want llama3.2", entry["id"])
|
||||
}
|
||||
if entry["name"] != "llama3.2" {
|
||||
t.Errorf("name = %v, want llama3.2", entry["name"])
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
@@ -1001,589 +876,3 @@ func TestOpenclawOnboarded(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawGatewayInfo(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns defaults when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reads token and port from config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"gateway": {
|
||||
"port": 9999,
|
||||
"auth": {"mode": "token", "token": "my-secret"}
|
||||
}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "my-secret" {
|
||||
t.Errorf("expected token %q, got %q", "my-secret", token)
|
||||
}
|
||||
if port != 9999 {
|
||||
t.Errorf("expected port 9999, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses default port when not in config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"gateway": {"auth": {"token": "tok"}}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "tok" {
|
||||
t.Errorf("expected token %q, got %q", "tok", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port %d, got %d", defaultGatewayPort, port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to legacy clawdbot config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"gateway": {"port": 12345, "auth": {"token": "legacy-token"}}
|
||||
}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "legacy-token" {
|
||||
t.Errorf("expected token %q, got %q", "legacy-token", token)
|
||||
}
|
||||
if port != 12345 {
|
||||
t.Errorf("expected port 12345, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port, got %d", port)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles missing gateway section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
token, port := c.gatewayInfo()
|
||||
if token != "" {
|
||||
t.Errorf("expected empty token, got %q", token)
|
||||
}
|
||||
if port != defaultGatewayPort {
|
||||
t.Errorf("expected default port, got %d", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrintOpenclawReady(t *testing.T) {
|
||||
t.Run("includes port in URL", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "", 9999, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "localhost:9999") {
|
||||
t.Errorf("expected port 9999 in output, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "#token=") {
|
||||
t.Error("should not include token fragment when token is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("URL-escapes token", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "my token&special=chars", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
escaped := url.QueryEscape("my token&special=chars")
|
||||
if !strings.Contains(output, "#token="+escaped) {
|
||||
t.Errorf("expected URL-escaped token %q in output, got:\n%s", escaped, output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple token is not mangled", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "#token=ollama") {
|
||||
t.Errorf("expected #token=ollama in output, got:\n%s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("includes web UI hint", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Open the Web UI") {
|
||||
t.Errorf("expected web UI hint in output, got:\n%s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("first launch shows quick start tips", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, true)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
for _, want := range []string{"/help", "channels", "skills", "gateway"} {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Errorf("expected %q in first-launch output, got:\n%s", want, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("subsequent launch shows single tip", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
old := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
printOpenclawReady("openclaw", "ollama", defaultGatewayPort, false)
|
||||
|
||||
w.Close()
|
||||
os.Stderr = old
|
||||
buf.ReadFrom(r)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "Tip:") {
|
||||
t.Errorf("expected single tip line, got:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "Quick start") {
|
||||
t.Errorf("should not show quick start on subsequent launch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModelConfig(t *testing.T) {
|
||||
t.Run("nil client returns base config", func(t *testing.T) {
|
||||
cfg, _ := openclawModelConfig(context.Background(), nil, "llama3.2")
|
||||
|
||||
if cfg["id"] != "llama3.2" {
|
||||
t.Errorf("id = %v, want llama3.2", cfg["id"])
|
||||
}
|
||||
if cfg["name"] != "llama3.2" {
|
||||
t.Errorf("name = %v, want llama3.2", cfg["name"])
|
||||
}
|
||||
if cfg["cost"] == nil {
|
||||
t.Error("cost should be set")
|
||||
}
|
||||
// Should not have capability fields without API
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set without API")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set without API")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets vision input when model has vision capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{"llama.context_length":4096}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llava:7b")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 1 {
|
||||
t.Errorf("input = %v, want [text]", cfg["input"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set for non-thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "qwq")
|
||||
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for thinking model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("extracts context window from model info", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "llama3.2")
|
||||
|
||||
if cfg["contextWindow"] != 131072 {
|
||||
t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles all capabilities together", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "qwen3-vision")
|
||||
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image]", cfg["input"])
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true")
|
||||
}
|
||||
if cfg["contextWindow"] != 32768 {
|
||||
t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns base config when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "missing-model")
|
||||
|
||||
if cfg["id"] != "missing-model" {
|
||||
t.Errorf("id = %v, want missing-model", cfg["id"])
|
||||
}
|
||||
// Should still have input (default)
|
||||
if cfg["input"] == nil {
|
||||
t.Error("input should always be set")
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set when show fails")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set when show fails")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("times out slow show and returns base config", func(t *testing.T) {
|
||||
oldTimeout := openclawModelShowTimeout
|
||||
openclawModelShowTimeout = 50 * time.Millisecond
|
||||
t.Cleanup(func() { openclawModelShowTimeout = oldTimeout })
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{"llama.context_length":4096}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
start := time.Now()
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "slow-model")
|
||||
elapsed := time.Since(start)
|
||||
if elapsed >= 250*time.Millisecond {
|
||||
t.Fatalf("openclawModelConfig took too long: %v", elapsed)
|
||||
}
|
||||
if cfg["id"] != "slow-model" {
|
||||
t.Errorf("id = %v, want slow-model", cfg["id"])
|
||||
}
|
||||
if _, ok := cfg["reasoning"]; ok {
|
||||
t.Error("reasoning should not be set on timeout")
|
||||
}
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set on timeout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips zero context length", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, _ := openclawModelConfig(context.Background(), client, "test-model")
|
||||
|
||||
if _, ok := cfg["contextWindow"]; ok {
|
||||
t.Error("contextWindow should not be set for zero value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model uses hardcoded limits", func(t *testing.T) {
|
||||
// Use a model name that's in cloudModelLimits and make the server
|
||||
// report it as a remote/cloud model
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud model")
|
||||
}
|
||||
if cfg["contextWindow"] != 204_800 {
|
||||
t.Errorf("contextWindow = %v, want 204800", cfg["contextWindow"])
|
||||
}
|
||||
if cfg["maxTokens"] != 128_000 {
|
||||
t.Errorf("maxTokens = %v, want 128000", cfg["maxTokens"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model with vision capability gets image input", func(t *testing.T) {
|
||||
// Regression test: cloud models must not skip capability detection.
|
||||
// A cloud model that reports vision capability should have input: [text, image].
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{},"remote_model":"qwen3-vl"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwen3-vl:235b-cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud vision model")
|
||||
}
|
||||
input, ok := cfg["input"].([]any)
|
||||
if !ok || len(input) != 2 {
|
||||
t.Errorf("input = %v, want [text image] for cloud vision model", cfg["input"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model with thinking capability gets reasoning flag", func(t *testing.T) {
|
||||
// Regression test: cloud models must not skip capability detection.
|
||||
// A cloud model that reports thinking capability should have reasoning: true.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{},"remote_model":"qwq-cloud"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "qwq:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud thinking model")
|
||||
}
|
||||
if cfg["reasoning"] != true {
|
||||
t.Error("expected reasoning = true for cloud thinking model")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationOnboarded(t *testing.T) {
|
||||
t.Run("returns false when not set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err == nil && integrationConfig.Onboarded {
|
||||
t.Error("expected false for fresh config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true after integrationOnboarded", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("is case insensitive", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := integrationOnboarded("OpenClaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true when set with different case")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves existing integration data", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
os.MkdirAll(filepath.Join(tmpDir, ".ollama"), 0o755)
|
||||
|
||||
if err := SaveIntegration("openclaw", []string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := integrationOnboarded("openclaw"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify onboarded is set
|
||||
integrationConfig, err := loadIntegration("openclaw")
|
||||
if err != nil || !integrationConfig.Onboarded {
|
||||
t.Error("expected true after integrationOnboarded")
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
model := IntegrationModel("openclaw")
|
||||
if model != "llama3.2" {
|
||||
t.Errorf("expected first model llama3.2, got %q", model)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,11 +10,10 @@ import (
|
||||
|
||||
// ANSI escape sequences for terminal formatting.
|
||||
const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiYellow = "\033[33m"
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
)
|
||||
|
||||
// ErrCancelled is returned when the user cancels a selection.
|
||||
|
||||
@@ -415,12 +415,6 @@ type multiSelectorModel struct {
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
@@ -435,23 +429,13 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -562,25 +546,14 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
m.toggleItem()
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
@@ -619,9 +592,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
m.toggleItem()
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
@@ -633,19 +604,6 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
@@ -657,7 +615,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
@@ -679,11 +637,6 @@ func (m multiSelectorModel) View() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
@@ -708,7 +661,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -731,7 +684,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,7 +704,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -763,18 +716,15 @@ func (m multiSelectorModel) View() string {
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
@@ -797,28 +747,18 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
if fm.cancelled {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
var result []string
|
||||
for _, idx := range fm.checkOrder {
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -539,7 +539,6 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "[x]") {
|
||||
@@ -551,18 +550,11 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiView_DefaultTag(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||
m.multi = true
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Error("should have (default) tag")
|
||||
}
|
||||
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||
aIdx := strings.Index(content, "a")
|
||||
defaultIdx := strings.Index(content, "(default)")
|
||||
if defaultIdx < aIdx {
|
||||
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||
t.Error("first checked item should have (default) tag")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -593,7 +585,6 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeySpace
|
||||
@@ -610,7 +601,6 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||
@@ -628,161 +618,6 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single-add mode ---
|
||||
|
||||
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
if m.multi {
|
||||
t.Error("should start in single mode (multi=false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
content := m.View()
|
||||
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||
t.Error("single mode should not show checkboxes")
|
||||
}
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("single mode should show cursor indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.cursor = 1
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if m.singleAdd != "b" {
|
||||
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||
}
|
||||
if !m.confirmed {
|
||||
t.Error("should set confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space in single mode should not toggle items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space rune in single mode should not toggle items")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune in single mode should not add to filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
|
||||
if m.multi {
|
||||
t.Fatal("should start in single mode")
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if !m.multi {
|
||||
t.Error("tab should switch to multi mode")
|
||||
}
|
||||
|
||||
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if m.multi {
|
||||
t.Error("tab should switch back to single mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab add multiple") {
|
||||
t.Error("single mode should show 'tab add multiple' in help")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
|
||||
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||
// preChecked[0] ("a") is the current default and should end up
|
||||
// last in checkOrder so it gets the (default) tag.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||
|
||||
if len(m.checkOrder) != 3 {
|
||||
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||
}
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||
// preChecked[0] ("b") is the default; cursor should start on it
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-mode last-checked is default ---
|
||||
|
||||
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||
m.multi = true
|
||||
|
||||
// Check "alpha" then "gamma"
|
||||
m.cursor = 0
|
||||
m.toggleItem()
|
||||
m.cursor = 2
|
||||
m.toggleItem()
|
||||
|
||||
// Last checked ("gamma") should be at the end of checkOrder
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "gamma" {
|
||||
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
|
||||
// The (default) tag renders based on checkOrder[len-1]
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatal("should show (default) tag")
|
||||
}
|
||||
// "alpha" line should NOT have the default tag
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -429,24 +429,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
for _, idx := range m.multiModalSelector.checkOrder {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
@@ -524,7 +508,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case "enter", " ":
|
||||
item := m.items[m.cursor]
|
||||
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) && !config.AutoInstallable(item.integration) {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -555,12 +539,6 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
item := m.items[m.cursor]
|
||||
if item.integration != "" || item.isRunModel {
|
||||
if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
// Auto-installable: select to trigger install flow
|
||||
m.selected = true
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if item.integration != "" && config.IsEditorIntegration(item.integration) {
|
||||
@@ -624,11 +602,7 @@ func (m model) View() string {
|
||||
var modelSuffix string
|
||||
if item.integration != "" {
|
||||
if !isInstalled {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
title += " " + notInstalledStyle.Render("(install)")
|
||||
} else {
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
}
|
||||
title += " " + notInstalledStyle.Render("(not installed)")
|
||||
} else if m.cursor == i {
|
||||
if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
|
||||
modelSuffix = " " + modelStyle.Render("("+mdl+")")
|
||||
@@ -644,9 +618,7 @@ func (m model) View() string {
|
||||
|
||||
desc := item.description
|
||||
if !isInstalled && item.integration != "" && m.cursor == i {
|
||||
if config.AutoInstallable(item.integration) {
|
||||
desc = "Press enter to install"
|
||||
} else if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
if hint := config.IntegrationInstallHint(item.integration); hint != "" {
|
||||
desc = hint
|
||||
} else {
|
||||
desc = "not installed"
|
||||
|
||||
48
docs/api.md
48
docs/api.md
@@ -15,6 +15,7 @@
|
||||
- [Push a Model](#push-a-model)
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Usage](#usage)
|
||||
- [Version](#version)
|
||||
- [Experimental: Image Generation](#image-generation-experimental)
|
||||
|
||||
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
GET /api/usage
|
||||
```
|
||||
|
||||
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/usage
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"start": "2025-01-27T20:00:00Z",
|
||||
"usage": [
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"requests": 5,
|
||||
"prompt_tokens": 130,
|
||||
"completion_tokens": 890
|
||||
},
|
||||
{
|
||||
"model": "deepseek-r1",
|
||||
"requests": 2,
|
||||
"prompt_tokens": 48,
|
||||
"completion_tokens": 312
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Response fields
|
||||
|
||||
- `start`: when the server started tracking usage (UTC, RFC 3339)
|
||||
- `usage`: list of per-model usage statistics
|
||||
- `model`: model name
|
||||
- `requests`: total number of completed requests
|
||||
- `prompt_tokens`: total prompt tokens evaluated
|
||||
- `completion_tokens`: total completion tokens generated
|
||||
|
||||
## Version
|
||||
|
||||
```
|
||||
|
||||
@@ -4,65 +4,47 @@ title: OpenClaw
|
||||
|
||||
OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
|
||||
|
||||
## Quick start
|
||||
## Install
|
||||
|
||||
Install [OpenClaw](https://openclaw.ai/)
|
||||
|
||||
```bash
|
||||
npm install -g openclaw@latest
|
||||
```
|
||||
|
||||
Then run the onboarding wizard:
|
||||
|
||||
```bash
|
||||
openclaw onboard --install-daemon
|
||||
```
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
Ollama handles everything automatically:
|
||||
|
||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||
3. **Model** — Pick a model from the selector (local or cloud)
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||
|
||||
## Configure without launching
|
||||
This configures OpenClaw to use Ollama and starts the gateway.
|
||||
If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
|
||||
|
||||
To change the model without starting the gateway and TUI:
|
||||
|
||||
```bash
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch openclaw --config
|
||||
```
|
||||
|
||||
To use a specific model directly:
|
||||
## Recommended Models
|
||||
|
||||
```bash
|
||||
ollama launch openclaw --model kimi-k2.5:cloud
|
||||
```
|
||||
|
||||
If the gateway is already running, it restarts automatically to pick up the new model.
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `glm-5:cloud` — Reasoning and code generation
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `glm-4.7-flash` — Reasoning and code generation locally (~25 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
```bash
|
||||
openclaw configure --section channels
|
||||
```
|
||||
|
||||
Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local models from anywhere.
|
||||
|
||||
## Stopping the gateway
|
||||
|
||||
```bash
|
||||
openclaw gateway stop
|
||||
```
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
1
go.mod
1
go.mod
@@ -26,7 +26,6 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -122,6 +122,7 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -149,9 +150,8 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -497,17 +496,6 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1239,102 +1238,3 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
# Wrap script in main function so that a truncated partial download doesn't end
|
||||
# up executing half a script.
|
||||
main() {
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
@@ -450,6 +446,3 @@ fi
|
||||
|
||||
status "NVIDIA GPU ready."
|
||||
install_success
|
||||
}
|
||||
|
||||
main
|
||||
|
||||
@@ -91,6 +91,8 @@ type Server struct {
|
||||
aliasesOnce sync.Once
|
||||
aliases *store
|
||||
aliasesErr error
|
||||
lowVRAM bool
|
||||
usage *UsageTracker
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
res.Context = tokens
|
||||
}
|
||||
|
||||
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
r.GET("/api/usage", s.UsageHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
@@ -1927,6 +1937,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (s *Server) UsageHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, s.usage.Stats())
|
||||
}
|
||||
|
||||
func (s *Server) PsHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
|
||||
@@ -2097,6 +2111,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.ChatResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -2317,6 +2335,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
|
||||
@@ -30,6 +30,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -224,6 +225,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -35,6 +35,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -220,6 +221,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/api/chat":
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 20,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case "/api/generate":
|
||||
resp := api.GenerateResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 5,
|
||||
EvalCount: 15,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected path %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
s := Server{}
|
||||
s := Server{usage: NewUsageTracker()}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud",
|
||||
RemoteHost: rs.URL,
|
||||
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote chat usage tracking", func(t *testing.T) {
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 10 {
|
||||
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 20 {
|
||||
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote generate usage tracking", func(t *testing.T) {
|
||||
// Reset the tracker for a clean test
|
||||
s.usage = NewUsageTracker()
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-cloud",
|
||||
Prompt: "hello",
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 5 {
|
||||
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 15 {
|
||||
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateChat(t *testing.T) {
|
||||
@@ -177,6 +252,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -894,6 +970,7 @@ func TestGenerate(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1378,6 +1455,7 @@ func TestGenerateLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1558,6 +1636,7 @@ func TestChatLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1668,6 +1747,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2114,6 +2194,7 @@ func TestGenerateUnload(t *testing.T) {
|
||||
var loadFnCalled bool
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2215,6 +2296,7 @@ func TestGenerateWithImages(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2393,6 +2475,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
62
server/usage.go
Normal file
62
server/usage.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ModelUsage struct {
|
||||
Requests int64
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
}
|
||||
|
||||
type UsageTracker struct {
|
||||
mu sync.Mutex
|
||||
start time.Time
|
||||
models map[string]*ModelUsage
|
||||
}
|
||||
|
||||
func NewUsageTracker() *UsageTracker {
|
||||
return &UsageTracker{
|
||||
start: time.Now().UTC(),
|
||||
models: make(map[string]*ModelUsage),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
m, ok := u.models[model]
|
||||
if !ok {
|
||||
m = &ModelUsage{}
|
||||
u.models[model] = m
|
||||
}
|
||||
|
||||
m.Requests++
|
||||
m.PromptTokens += int64(promptTokens)
|
||||
m.CompletionTokens += int64(completionTokens)
|
||||
}
|
||||
|
||||
func (u *UsageTracker) Stats() api.UsageResponse {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
byModel := make([]api.ModelUsageData, 0, len(u.models))
|
||||
for model, usage := range u.models {
|
||||
byModel = append(byModel, api.ModelUsageData{
|
||||
Model: model,
|
||||
Requests: usage.Requests,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return api.UsageResponse{
|
||||
Start: u.start,
|
||||
Usage: byModel,
|
||||
}
|
||||
}
|
||||
136
server/usage_test.go
Normal file
136
server/usage_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestUsageTrackerRecord(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
tracker.Record("model-a", 10, 20)
|
||||
tracker.Record("model-a", 5, 15)
|
||||
tracker.Record("model-b", 100, 200)
|
||||
|
||||
stats := tracker.Stats()
|
||||
|
||||
if len(stats.Usage) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
lookup := make(map[string]api.ModelUsageData)
|
||||
for _, m := range stats.Usage {
|
||||
lookup[m.Model] = m
|
||||
}
|
||||
|
||||
a := lookup["model-a"]
|
||||
if a.Requests != 2 {
|
||||
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
|
||||
}
|
||||
if a.PromptTokens != 15 {
|
||||
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
|
||||
}
|
||||
if a.CompletionTokens != 35 {
|
||||
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
|
||||
}
|
||||
|
||||
b := lookup["model-b"]
|
||||
if b.Requests != 1 {
|
||||
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
|
||||
}
|
||||
if b.PromptTokens != 100 {
|
||||
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
|
||||
}
|
||||
if b.CompletionTokens != 200 {
|
||||
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerConcurrent(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Record("model-a", 1, 2)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if len(stats.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
m := stats.Usage[0]
|
||||
if m.Requests != 100 {
|
||||
t.Errorf("requests: expected 100, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 100 {
|
||||
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 200 {
|
||||
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerStart(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if stats.Start.IsZero() {
|
||||
t.Error("expected non-zero start time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
}
|
||||
|
||||
s.usage.Record("llama3", 50, 100)
|
||||
s.usage.Record("llama3", 25, 50)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
|
||||
|
||||
s.UsageHandler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.UsageResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
|
||||
}
|
||||
|
||||
m := resp.Usage[0]
|
||||
if m.Model != "llama3" {
|
||||
t.Errorf("expected model llama3, got %s", m.Model)
|
||||
}
|
||||
if m.Requests != 2 {
|
||||
t.Errorf("expected 2 requests, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 75 {
|
||||
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 150 {
|
||||
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,7 @@
|
||||
|
||||
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
|
||||
handle->ctx = (void*) DLOPEN(path);
|
||||
if (handle->ctx == NULL) {
|
||||
return 1;
|
||||
}
|
||||
CHECK(handle->ctx != NULL);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -55,30 +55,6 @@ func tryLoadFromDir(dir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tryLoadByName attempts to load the library using just its name,
|
||||
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadByName() bool {
|
||||
libraryName := "libmlxc.dylib"
|
||||
if runtime.GOOS == "linux" {
|
||||
libraryName = "libmlxc.so"
|
||||
}
|
||||
|
||||
cPath := C.CString(libraryName)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
return false
|
||||
}
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -97,11 +73,6 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// Try loading via rpath/standard library search
|
||||
if tryLoadByName() {
|
||||
return
|
||||
}
|
||||
|
||||
// Build search paths: executable directory, then build directories
|
||||
var searchDirs []string
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
@@ -125,5 +126,13 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
return flushValidUTF8Prefix(b)
|
||||
if text := b.String(); utf8.ValidString(text) {
|
||||
b.Reset()
|
||||
return text
|
||||
} else if b.Len() >= utf8.UTFMax {
|
||||
b.Reset()
|
||||
return text
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
|
||||
// currently buffered, leaving any incomplete trailing bytes in place.
|
||||
func flushValidUTF8Prefix(b *bytes.Buffer) string {
|
||||
data := b.Bytes()
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := validUTF8PrefixLen(data)
|
||||
if prefix == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := string(data[:prefix])
|
||||
b.Next(prefix)
|
||||
return text
|
||||
}
|
||||
|
||||
func validUTF8PrefixLen(data []byte) int {
|
||||
i := 0
|
||||
prefix := 0
|
||||
for i < len(data) {
|
||||
r, size := utf8.DecodeRune(data[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
if !utf8.FullRune(data[i:]) {
|
||||
break
|
||||
}
|
||||
|
||||
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
|
||||
i++
|
||||
prefix = i
|
||||
continue
|
||||
}
|
||||
|
||||
i += size
|
||||
prefix = i
|
||||
}
|
||||
|
||||
return prefix
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
||||
b.Write([]byte{0xE3, 0x81})
|
||||
if got := flushValidUTF8Prefix(&b); got != "" {
|
||||
t.Fatalf("first flush = %q, want empty", got)
|
||||
}
|
||||
|
||||
b.Write([]byte{0x93, 0xE3})
|
||||
if got := flushValidUTF8Prefix(&b); got != "こ" {
|
||||
t.Fatalf("second flush = %q, want %q", got, "こ")
|
||||
}
|
||||
|
||||
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
|
||||
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
|
||||
}
|
||||
|
||||
b.Write([]byte{0x82, 0x93})
|
||||
if got := flushValidUTF8Prefix(&b); got != "ん" {
|
||||
t.Fatalf("third flush = %q, want %q", got, "ん")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after third flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("hello 世界")
|
||||
|
||||
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
|
||||
t.Fatalf("flush = %q, want %q", got, "hello 世界")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
|
||||
//
|
||||
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
|
||||
// - GPT-2 byte-level encoding (OpenAI tiktoken)
|
||||
// - HuggingFace tokenizer.json pretokenizer patterns
|
||||
// - SentencePiece ▁-style space handling
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "regexp"
|
||||
|
||||
// TokenizerType identifies the tokenization algorithm
|
||||
type TokenizerType int
|
||||
|
||||
const (
|
||||
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
|
||||
TokenizerSentencePiece // SentencePiece with ▁ for spaces
|
||||
)
|
||||
|
||||
// Vocabulary holds the tokenizer vocabulary and merges
|
||||
type Vocabulary struct {
|
||||
Values []string
|
||||
Reverse map[string]int32
|
||||
Merges map[string]int
|
||||
|
||||
BOS int32
|
||||
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
|
||||
PAD int32 // Padding token (often <|endoftext|> or <pad>)
|
||||
AddBOS bool
|
||||
AddEOS bool
|
||||
|
||||
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
|
||||
byteTokens [256]int32
|
||||
}
|
||||
|
||||
// Tokenizer handles BPE and SentencePiece tokenization
|
||||
type Tokenizer struct {
|
||||
vocab *Vocabulary
|
||||
pretokenizer *regexp.Regexp
|
||||
specialTokens map[string]int32 // Special tokens for direct lookup
|
||||
sortedSpecialTokens []string // Special tokens sorted by length, longest first
|
||||
typ TokenizerType // Algorithm type
|
||||
}
|
||||
|
||||
// Precomputed GPT-2 byte-level encoding table
|
||||
// Maps byte values to their encoded rune equivalents
|
||||
var byteToRune [256]rune
|
||||
|
||||
func init() {
|
||||
for b := 0; b < 256; b++ {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
byteToRune[b] = r
|
||||
}
|
||||
}
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (t *Tokenizer) VocabSize() int {
|
||||
return len(t.vocab.Values)
|
||||
}
|
||||
|
||||
// BOS returns the beginning of sequence token ID
|
||||
func (t *Tokenizer) BOS() int32 {
|
||||
return t.vocab.BOS
|
||||
}
|
||||
|
||||
// EOS returns the first end of sequence token ID (for backwards compatibility)
|
||||
func (t *Tokenizer) EOS() int32 {
|
||||
if len(t.vocab.EOS) > 0 {
|
||||
return t.vocab.EOS[0]
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// EOSTokens returns all end of sequence token IDs
|
||||
func (t *Tokenizer) EOSTokens() []int32 {
|
||||
return t.vocab.EOS
|
||||
}
|
||||
|
||||
// PAD returns the padding token ID, or -1 if not set
|
||||
func (t *Tokenizer) PAD() int32 {
|
||||
return t.vocab.PAD
|
||||
}
|
||||
|
||||
// IsEOS returns true if the token ID is an end of sequence token
|
||||
func (t *Tokenizer) IsEOS(id int32) bool {
|
||||
for _, eos := range t.vocab.EOS {
|
||||
if id == eos {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSpecialToken returns the token ID for a special token string
|
||||
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
|
||||
id, ok := t.specialTokens[name]
|
||||
return id, ok
|
||||
}
|
||||
@@ -1,251 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
benchmarkSinkIDs []int32
|
||||
benchmarkSinkStr string
|
||||
benchmarkSinkTok *Tokenizer
|
||||
)
|
||||
|
||||
const benchmarkWordPieceJSON = `{
|
||||
"model": {
|
||||
"type": "WordPiece",
|
||||
"vocab": {
|
||||
"[UNK]": 0,
|
||||
"hello": 1,
|
||||
"##world": 2,
|
||||
"##ly": 3,
|
||||
"##hello": 4
|
||||
}
|
||||
},
|
||||
"added_tokens": []
|
||||
}`
|
||||
|
||||
const benchmarkSentencePieceJSON = `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"\u2581": 0,
|
||||
"h": 1,
|
||||
"e": 2,
|
||||
"l": 3,
|
||||
"o": 4,
|
||||
"w": 5,
|
||||
"r": 6,
|
||||
"d": 7,
|
||||
"<0x0A>": 8
|
||||
},
|
||||
"merges": []
|
||||
},
|
||||
"decoder": {
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{
|
||||
"type": "Replace",
|
||||
"pattern": {
|
||||
"String": "\u2581"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"added_tokens": []
|
||||
}`
|
||||
|
||||
func benchmarkMiniLlamaPath(tb testing.TB) string {
|
||||
tb.Helper()
|
||||
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
tb.Fatal("failed to resolve benchmark file path")
|
||||
}
|
||||
|
||||
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
|
||||
}
|
||||
|
||||
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
data := benchmarkLoadMiniLlamaBytes(tb)
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
|
||||
tb.Helper()
|
||||
|
||||
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "short", text: "Hello, world!"},
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
|
||||
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
|
||||
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input.text, false)
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
|
||||
data := benchmarkLoadMiniLlamaBytes(b)
|
||||
|
||||
config := &TokenizerConfig{
|
||||
TokenizerConfigJSON: []byte(`{
|
||||
"bos_token": {"content": "<|begin_of_text|>"},
|
||||
"eos_token": {"content": "<|end_of_text|>"},
|
||||
"add_bos_token": true
|
||||
}`),
|
||||
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
|
||||
}
|
||||
|
||||
b.Run("without_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytes failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("with_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytesWithConfig(data, config)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "container/heap"
|
||||
|
||||
type bpeMergeNode struct {
|
||||
prev int
|
||||
next int
|
||||
token string
|
||||
}
|
||||
|
||||
type bpePair struct {
|
||||
left int
|
||||
right int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type bpePairHeap []*bpePair
|
||||
|
||||
func (h bpePairHeap) Len() int { return len(h) }
|
||||
|
||||
func (h bpePairHeap) Less(i, j int) bool {
|
||||
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
|
||||
}
|
||||
|
||||
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
func (h *bpePairHeap) Push(x any) {
|
||||
*h = append(*h, x.(*bpePair))
|
||||
}
|
||||
|
||||
func (h *bpePairHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
// encodeBPEMerge encodes using BPE merge algorithm.
|
||||
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
|
||||
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
|
||||
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
|
||||
runes := []rune(encoded)
|
||||
if len(runes) == 0 {
|
||||
return ids
|
||||
}
|
||||
|
||||
nodes := make([]bpeMergeNode, len(runes))
|
||||
for i := range runes {
|
||||
nodes[i] = bpeMergeNode{
|
||||
prev: i - 1,
|
||||
next: i + 1,
|
||||
token: string(runes[i]),
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(left, right int) *bpePair {
|
||||
if left < 0 || right >= len(nodes) {
|
||||
return nil
|
||||
}
|
||||
if nodes[left].token == "" || nodes[right].token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
leftToken, rightToken := nodes[left].token, nodes[right].token
|
||||
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := leftToken + rightToken
|
||||
if _, ok := t.vocab.Reverse[value]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &bpePair{
|
||||
left: left,
|
||||
right: right,
|
||||
rank: rank,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := bpePairHeap{}
|
||||
heap.Init(&pairs)
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for pairs.Len() > 0 {
|
||||
pair := heap.Pop(&pairs).(*bpePair)
|
||||
left, right := nodes[pair.left], nodes[pair.right]
|
||||
if left.token == "" || right.token == "" {
|
||||
continue
|
||||
}
|
||||
if left.next != pair.right || right.prev != pair.left {
|
||||
continue
|
||||
}
|
||||
if left.token+right.token != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes[pair.left].token = pair.value
|
||||
nodes[pair.right].token = ""
|
||||
nodes[pair.left].next = right.next
|
||||
if right.next < len(nodes) {
|
||||
nodes[right.next].prev = pair.left
|
||||
}
|
||||
|
||||
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.token == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if id, ok := t.vocab.Reverse[node.token]; ok {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
ids = t.appendByteFallback(ids, node.token)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
|
||||
if t.typ == TokenizerBPE {
|
||||
for _, r := range token {
|
||||
if b, ok := decodeByteLevelRune(r); ok {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
|
||||
for _, b := range []byte(token) {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func decodeByteLevelRune(r rune) (byte, bool) {
|
||||
switch {
|
||||
case r >= 0x00 && r <= 0xFF:
|
||||
return byte(r), true
|
||||
case r == 0x0100:
|
||||
return 0x00, true
|
||||
case r == 0x0143:
|
||||
return 0x00ad, true
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
return byte(r - 0x0100), true
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
return byte(r - 0x00a2), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func equalIDs(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestEncodeRoundtripMiniLlama(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
inputs := []string{
|
||||
"",
|
||||
"hello",
|
||||
"hello world",
|
||||
" hello world ",
|
||||
"don't we'll they're",
|
||||
"1234567890",
|
||||
"こんにちは世界",
|
||||
"Hello 世界",
|
||||
"func main() {}",
|
||||
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
|
||||
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
// Simulate construction outside loader path where cache is not set.
|
||||
tok.sortedSpecialTokens = nil
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
|
||||
|
||||
prev := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(prev)
|
||||
|
||||
runtime.GOMAXPROCS(1)
|
||||
seq := tok.Encode(input, false)
|
||||
|
||||
if prev < 2 {
|
||||
runtime.GOMAXPROCS(2)
|
||||
} else {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
}
|
||||
par := tok.Encode(input, false)
|
||||
|
||||
if !equalIDs(seq, par) {
|
||||
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Decode converts token IDs back to text
|
||||
func (t *Tokenizer) Decode(ids []int32) string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, id := range ids {
|
||||
if int(id) >= len(t.vocab.Values) {
|
||||
continue
|
||||
}
|
||||
|
||||
token := t.vocab.Values[id]
|
||||
|
||||
switch t.typ {
|
||||
case TokenizerSentencePiece:
|
||||
// SentencePiece style: replace ▁ with space, decode byte tokens
|
||||
token = strings.ReplaceAll(token, "▁", " ")
|
||||
// Handle byte fallback tokens like <0x0D>
|
||||
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
|
||||
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
|
||||
sb.WriteByte(byte(v))
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString(token)
|
||||
default:
|
||||
// GPT-2 BPE style: decode byte-level encoding
|
||||
for _, r := range token {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// Mirror GGML tokenizer behavior for NULL byte.
|
||||
// 0x00 is omitted during decode.
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// Write as byte, not UTF-8 encoded rune
|
||||
sb.WriteByte(byte(r))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1,289 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
encodeParallelMinInputBytes = 4 * 1024
|
||||
encodeParallelMinChunksPerWorker = 8
|
||||
)
|
||||
|
||||
type tokenMatch struct {
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
type encodeChunk struct {
|
||||
text string
|
||||
isSpecial bool
|
||||
}
|
||||
|
||||
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
|
||||
func isNonNewlineWhitespace(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r == '\n' || r == '\r' {
|
||||
return false
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
|
||||
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
|
||||
if len(t.specialTokens) == 0 {
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
tokens := t.sortedSpecialTokens
|
||||
if len(tokens) == 0 {
|
||||
// Fallback for tokenizers constructed outside the loaders.
|
||||
tokens = make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
}
|
||||
|
||||
var result []string
|
||||
remaining := s
|
||||
|
||||
for len(remaining) > 0 {
|
||||
found := false
|
||||
for _, tok := range tokens {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
result = append(result, tok)
|
||||
remaining = remaining[len(tok):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Find next special token position
|
||||
nextPos := len(remaining)
|
||||
for _, tok := range tokens {
|
||||
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
|
||||
nextPos = idx
|
||||
}
|
||||
}
|
||||
if nextPos > 0 {
|
||||
result = append(result, remaining[:nextPos])
|
||||
}
|
||||
remaining = remaining[nextPos:]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
|
||||
m := part[curr.start:curr.end]
|
||||
nextText := part[next.start:next.end]
|
||||
|
||||
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
firstRune, _ := utf8.DecodeRuneInString(nextText)
|
||||
if !unicode.IsLetter(firstRune) {
|
||||
return
|
||||
}
|
||||
|
||||
lastSpaceStart := curr.end
|
||||
for j := curr.end; j > curr.start; {
|
||||
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
|
||||
if unicode.IsSpace(r) {
|
||||
lastSpaceStart = j - size
|
||||
break
|
||||
}
|
||||
j -= size
|
||||
}
|
||||
if lastSpaceStart > curr.start {
|
||||
curr.end = lastSpaceStart
|
||||
next.start = lastSpaceStart
|
||||
} else {
|
||||
next.start = curr.start
|
||||
curr.end = curr.start
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
|
||||
if _, ok := t.specialTokens[part]; ok {
|
||||
fn(encodeChunk{text: part, isSpecial: true})
|
||||
return
|
||||
}
|
||||
|
||||
if t.pretokenizer == nil {
|
||||
fn(encodeChunk{text: part, isSpecial: false})
|
||||
return
|
||||
}
|
||||
|
||||
re := t.pretokenizer
|
||||
offset := 0
|
||||
loc := re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
for {
|
||||
loc = re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
adjustWhitespaceBoundary(part, &curr, &next)
|
||||
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
curr = next
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
|
||||
if c.isSpecial {
|
||||
if id, ok := t.specialTokens[c.text]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
return t.encodeChunkInto(c.text, ids)
|
||||
}
|
||||
|
||||
// Encode tokenizes text to token IDs.
|
||||
// Parallel encoding is used only for very large inputs with enough chunks per worker.
|
||||
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
|
||||
// First: split by special tokens
|
||||
parts := t.splitBySpecialTokens(s)
|
||||
|
||||
// Fast path: encode sequentially without materializing chunk slices.
|
||||
if len(s) < encodeParallelMinInputBytes {
|
||||
var ids []int32
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
})
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// For large inputs collect chunks to enable parallel processing.
|
||||
var allChunks []encodeChunk
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
allChunks = append(allChunks, c)
|
||||
})
|
||||
}
|
||||
|
||||
// Encode chunks. Use the parallel path only when the chunk count is
|
||||
// large enough to amortize goroutine/synchronization overhead.
|
||||
useParallel := true
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
if numWorkers > len(allChunks) {
|
||||
numWorkers = len(allChunks)
|
||||
}
|
||||
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
|
||||
useParallel = false
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
if !useParallel {
|
||||
for _, c := range allChunks {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
}
|
||||
} else {
|
||||
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
|
||||
results := make([][]int32, numWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
start := i * chunksPer
|
||||
end := start + chunksPer
|
||||
if end > len(allChunks) {
|
||||
end = len(allChunks)
|
||||
}
|
||||
if start >= end {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(i int, chunks []encodeChunk) {
|
||||
defer wg.Done()
|
||||
var r []int32
|
||||
for _, c := range chunks {
|
||||
r = t.appendEncodedChunk(r, c)
|
||||
}
|
||||
results[i] = r
|
||||
}(i, allChunks[start:end])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, r := range results {
|
||||
ids = append(ids, r...)
|
||||
}
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
|
||||
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
|
||||
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
|
||||
if s == "" {
|
||||
return ids
|
||||
}
|
||||
|
||||
// Apply encoding transformation
|
||||
// SentencePiece: replace space with ▁
|
||||
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
|
||||
var encoded string
|
||||
if t.typ == TokenizerSentencePiece {
|
||||
encoded = strings.ReplaceAll(s, " ", "▁")
|
||||
} else {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(s) * 2)
|
||||
for i := 0; i < len(s); i++ {
|
||||
sb.WriteRune(byteToRune[s[i]])
|
||||
}
|
||||
encoded = sb.String()
|
||||
}
|
||||
|
||||
// Fast path: check if entire chunk is a single token
|
||||
if id, ok := t.vocab.Reverse[encoded]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
|
||||
return t.encodeBPEMerge(encoded, ids)
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// llama32GGMLFixturePath returns the path of a file inside the shared
// llama3.2 tokenizer test fixture directory, resolved relative to this
// source file's location.
func llama32GGMLFixturePath(tb testing.TB, file string) string {
	tb.Helper()

	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		tb.Fatal("failed to resolve test file path")
	}

	// Fixtures live two directories up, under tokenizer/testdata/llama3.2.
	base := filepath.Dir(thisFile)
	return filepath.Join(base, "..", "..", "tokenizer", "testdata", "llama3.2", file)
}
|
||||
|
||||
// loadLlama32FromGGMLFixture builds a Tokenizer from the GPT-2-style GGML
// fixture files (encoder.json vocabulary + vocab.bpe merges) by synthesizing
// an equivalent HuggingFace tokenizer.json payload in memory and loading it
// with LoadFromBytes.
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
	tb.Helper()

	// Vocabulary: token -> id map straight from encoder.json.
	f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
	if err != nil {
		tb.Fatalf("failed to open encoder.json: %v", err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		tb.Fatalf("failed to decode encoder.json: %v", err)
	}

	// The GGML fixture lacks the llama control tokens; append any that are
	// missing with fresh IDs so encode tests can reference them.
	type addedToken struct {
		ID      int32  `json:"id"`
		Content string `json:"content"`
		Special bool   `json:"special"`
	}
	var addedTokens []addedToken
	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			id := int32(len(vocab))
			vocab[token] = id
			addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
		}
	}

	// Merge rules: one rule per line; '#'-prefixed lines are comments and
	// blank lines are skipped.
	mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
	if err != nil {
		tb.Fatalf("failed to open vocab.bpe: %v", err)
	}
	defer mf.Close()

	var merges []string
	scanner := bufio.NewScanner(mf)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimSpace(line)
		if line != "" {
			merges = append(merges, line)
		}
	}
	if err := scanner.Err(); err != nil {
		tb.Fatalf("failed to read vocab.bpe: %v", err)
	}

	// Assemble a minimal tokenizer.json-shaped document: a BPE model plus a
	// Sequence pre_tokenizer containing a single Split rule with the llama3
	// pretokenization regex.
	payload := struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges []string         `json:"merges"`
		} `json:"model"`
		PreTokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					Regex string `json:"Regex"`
				} `json:"pattern"`
			} `json:"pretokenizers"`
		} `json:"pre_tokenizer"`
		AddedTokens []addedToken `json:"added_tokens"`
	}{}

	payload.Model.Type = "BPE"
	payload.Model.Vocab = vocab
	payload.Model.Merges = merges
	payload.PreTokenizer.Type = "Sequence"
	payload.PreTokenizer.Pretokenizers = []struct {
		Type    string `json:"type"`
		Pattern struct {
			Regex string `json:"Regex"`
		} `json:"pattern"`
	}{
		{
			Type: "Split",
			Pattern: struct {
				Regex string `json:"Regex"`
			}{
				Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
			},
		},
	}
	payload.AddedTokens = addedTokens

	data, err := json.Marshal(payload)
	if err != nil {
		tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
	}

	tok, err := LoadFromBytes(data)
	if err != nil {
		tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
	}
	return tok
}
|
||||
|
||||
func TestGGMLLlamaKnownEncodings(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[string][]int32{
|
||||
"hello world": {15339, 1917},
|
||||
"hello <|end_of_text|>": {15339, 220, 128001},
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[int][]int32{
|
||||
1: {15},
|
||||
2: {410},
|
||||
3: {931},
|
||||
4: {931, 15},
|
||||
5: {931, 410},
|
||||
6: {931, 931},
|
||||
7: {931, 931, 15},
|
||||
8: {931, 931, 410},
|
||||
9: {931, 931, 931},
|
||||
10: {931, 931, 931, 15},
|
||||
11: {931, 931, 931, 410},
|
||||
12: {931, 931, 931, 931},
|
||||
13: {931, 931, 931, 931, 15},
|
||||
14: {931, 931, 931, 931, 410},
|
||||
15: {931, 931, 931, 931, 931},
|
||||
16: {931, 931, 931, 931, 931, 15},
|
||||
17: {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for n, want := range cases {
|
||||
input := strings.Repeat("0", n)
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, input := range cases {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
|
||||
ids := tok.Encode(string(rune(0x00)), false)
|
||||
got := tok.Decode(ids)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
|
||||
}
|
||||
}
|
||||
@@ -1,458 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TokenizerConfig holds optional configuration data that can be passed to
// LoadFromBytesWithConfig. Each field carries the raw bytes of the
// corresponding HuggingFace companion file; any field may be nil, in which
// case that source is skipped when resolving special tokens.
type TokenizerConfig struct {
	TokenizerConfigJSON  []byte // tokenizer_config.json content
	GenerationConfigJSON []byte // generation_config.json content
	SpecialTokensMapJSON []byte // special_tokens_map.json content
	ConfigJSON           []byte // config.json content
}
|
||||
|
||||
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
//
// Note: This won't load special token config from companion files, so BOS/EOS/PAD
// remain unset. Use LoadFromBytesWithConfig to provide tokenizer_config.json
// (and related) data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
	return loadFromTokenizerJSON(data)
}
|
||||
|
||||
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
|
||||
// This is useful when loading from blob storage where companion config files are also blobs.
|
||||
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
|
||||
t, err := loadFromTokenizerJSON(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Apply special token configs from provided data
|
||||
loadSpecialTokenConfigFromBytes(t, config)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// loadFromTokenizerJSON parses tokenizer.json content from bytes and builds a
// Tokenizer. Only "BPE"-type models are accepted; SentencePiece models are
// also distributed as BPE and are distinguished afterwards by their decoder
// config (see detectSentencePiece).
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {

	var raw struct {
		Model struct {
			Type  string           `json:"type"` // "BPE"
			Vocab map[string]int32 `json:"vocab"`
			// Merges is kept raw because the field can be []string or
			// [][]string depending on the exporter (BPE only).
			Merges json.RawMessage `json:"merges"`
		} `json:"model"`
		PreTokenizer json.RawMessage `json:"pre_tokenizer"`
		Decoder      json.RawMessage `json:"decoder"`
		AddedTokens  []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
	}

	if err := json.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
	}

	// Covers SentencePiece and BPE models
	if raw.Model.Type != "BPE" {
		return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
	}

	// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
	var mergesStrings []string
	if raw.Model.Merges != nil {
		var mergesArrays [][]string
		if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
			// Try array of arrays format
			if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
				return nil, fmt.Errorf("failed to parse merges: %w", err)
			}
			// Convert [][]string to the internal "left right" form.
			mergesStrings = make([]string, len(mergesArrays))
			for i, pair := range mergesArrays {
				if len(pair) != 2 {
					return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
				}
				mergesStrings[i] = pair[0] + " " + pair[1]
			}
		}
	}

	// Build tokenizer. BOS and PAD start at -1 (unset) until companion
	// config files provide values.
	t := &Tokenizer{
		vocab: &Vocabulary{
			Values:  make([]string, len(raw.Model.Vocab)),
			Reverse: raw.Model.Vocab,
			Merges:  make(map[string]int, len(mergesStrings)),
			BOS:     -1,
			PAD:     -1,
		},
		specialTokens: make(map[string]int32),
	}

	// Build the id -> token values array, growing it when IDs exceed the
	// initial vocab size (IDs are not guaranteed dense).
	for token, id := range raw.Model.Vocab {
		if int(id) >= len(t.vocab.Values) {
			newValues := make([]string, id+1)
			copy(newValues, t.vocab.Values)
			t.vocab.Values = newValues
		}
		t.vocab.Values[id] = token
	}

	// Build merges map: merge rule -> rank (earlier rules merge first).
	for i, merge := range mergesStrings {
		t.vocab.Merges[merge] = i
	}

	// Add all added_tokens to vocabulary and special tokens map.
	// HuggingFace treats ALL added_tokens as special for tokenization purposes -
	// they bypass BPE and get their own token ID. The "special" flag just indicates
	// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
	// to treat all added_tokens as special to match HuggingFace behavior.
	for _, tok := range raw.AddedTokens {
		if int(tok.ID) >= len(t.vocab.Values) {
			newValues := make([]string, tok.ID+1)
			copy(newValues, t.vocab.Values)
			t.vocab.Values = newValues
		}
		t.vocab.Values[tok.ID] = tok.Content
		t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
	}

	// Precompute byte token IDs for <0xNN> fallback
	initByteTokens(t)

	// Determine tokenizer type from the decoder configuration.
	switch {
	case detectSentencePiece(raw.Decoder):
		t.typ = TokenizerSentencePiece
	default:
		t.typ = TokenizerBPE
	}

	// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
	if t.typ == TokenizerBPE {
		pattern := extractPretokenizer(raw.PreTokenizer)
		if pattern == "" {
			// Fall back to the GPT-2 pattern when none is configured.
			pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
		}
		re, err := regexp.Compile(rewritePatternForRE2(pattern))
		if err != nil {
			return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
		}
		t.pretokenizer = re
	}

	cacheSortedSpecialTokens(t)

	return t, nil
}
|
||||
|
||||
func cacheSortedSpecialTokens(t *Tokenizer) {
|
||||
if len(t.specialTokens) == 0 {
|
||||
t.sortedSpecialTokens = nil
|
||||
return
|
||||
}
|
||||
|
||||
tokens := make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
t.sortedSpecialTokens = tokens
|
||||
}
|
||||
|
||||
// specialTokenConfigData bundles the raw companion-config file contents that
// applySpecialTokenConfig consumes. It mirrors TokenizerConfig but is
// internal to this package.
type specialTokenConfigData struct {
	tokenizerConfigJSON  []byte // tokenizer_config.json content
	generationConfigJSON []byte // generation_config.json content
	specialTokensMapJSON []byte // special_tokens_map.json content
	configJSON           []byte // config.json content
}
|
||||
|
||||
// applySpecialTokenConfig resolves BOS/EOS/PAD token IDs and the AddBOS/AddEOS
// flags on t from the HuggingFace companion config files, in priority order:
//
//	1. generation_config.json  (numeric token IDs)
//	2. config.json             (numeric token IDs; only fills remaining gaps)
//	3. tokenizer_config.json   (token strings, resolved via t.specialTokens)
//	4. special_tokens_map.json (token strings; only fills remaining gaps)
//
// Earlier sources win: a later source only sets a value that is still unset
// (BOS < 0, PAD < 0, or an empty EOS list). Missing or unparseable files are
// silently skipped.
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
	// parseTokenIDs accepts the two JSON shapes used for token IDs:
	// a single number, or an array of numbers (non-numbers are skipped).
	parseTokenIDs := func(v interface{}) []int32 {
		switch val := v.(type) {
		case float64:
			return []int32{int32(val)}
		case []interface{}:
			ids := make([]int32, 0, len(val))
			for _, id := range val {
				if f, ok := id.(float64); ok {
					ids = append(ids, int32(f))
				}
			}
			return ids
		}
		return nil
	}

	// Priority 1: generation_config.json
	if len(config.generationConfigJSON) > 0 {
		var genConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
			if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
				t.vocab.EOS = ids
			}
			if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
				t.vocab.BOS = ids[0]
			}
		}
	}

	// Priority 2: config.json (consulted only while BOS or EOS is still unset)
	if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
		var modelConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
			if len(t.vocab.EOS) == 0 {
				if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
					t.vocab.EOS = ids
				}
			}
			if t.vocab.BOS < 0 {
				if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
					t.vocab.BOS = ids[0]
				}
			}
		}
	}

	// Priority 3: tokenizer_config.json (token strings, not IDs; also the
	// only source for the add_bos_token/add_eos_token flags)
	if len(config.tokenizerConfigJSON) > 0 {
		var tokConfig struct {
			BOSToken    interface{} `json:"bos_token"`
			EOSToken    interface{} `json:"eos_token"`
			PADToken    interface{} `json:"pad_token"`
			AddBOSToken *bool       `json:"add_bos_token"`
			AddEOSToken *bool       `json:"add_eos_token"`
		}
		if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
			// Pointers distinguish "absent" from an explicit false.
			if tokConfig.AddBOSToken != nil {
				t.vocab.AddBOS = *tokConfig.AddBOSToken
			}
			if tokConfig.AddEOSToken != nil {
				t.vocab.AddEOS = *tokConfig.AddEOSToken
			}
		}
	}

	// Priority 4: special_tokens_map.json (token strings; last resort)
	if len(config.specialTokensMapJSON) > 0 {
		var tokensMap map[string]interface{}
		if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
		}
	}
}
|
||||
|
||||
// extractTokenString extracts the token string from the formats HuggingFace
// configs use to represent a token:
//   - a bare string: "token"
//   - an object:     {"content": "token", ...}
//
// Anything else (including nil) yields "".
func extractTokenString(v interface{}) string {
	switch tok := v.(type) {
	case string:
		return tok
	case map[string]interface{}:
		if content, ok := tok["content"].(string); ok {
			return content
		}
	}
	return ""
}
|
||||
|
||||
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE
// features RE2 lacks:
//   - (?!\S)   negative lookahead
//   - (?i:...) inline case-insensitive groups
//
// The lookahead alternation \s+(?!\S)|\s+ is collapsed to plain \s+; the lost
// boundary behavior is restored by whitespace post-processing during encoding
// (the lookahead version splits "a  b" into ["a", " ", " b"], plain \s+ gives
// ["a", "  ", "b"]). Case-insensitive contraction groups are expanded to
// explicit character-class alternations.
func rewritePatternForRE2(pattern string) string {
	// Ordered rewrites. The optional "(...)?" contraction form must be
	// handled before the bare form to avoid a partial replacement.
	rewrites := [][2]string{
		{`\s+(?!\S)|\s+`, `\s+`},
		{`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`},
		{`(?i:'s|'t|'re|'ve|'m|'ll|'d)`, `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`},
	}
	for _, r := range rewrites {
		pattern = strings.ReplaceAll(pattern, r[0], r[1])
	}
	return pattern
}
|
||||
|
||||
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
|
||||
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
|
||||
applySpecialTokenConfig(t, specialTokenConfigData{
|
||||
tokenizerConfigJSON: config.TokenizerConfigJSON,
|
||||
generationConfigJSON: config.GenerationConfigJSON,
|
||||
specialTokensMapJSON: config.SpecialTokensMapJSON,
|
||||
configJSON: config.ConfigJSON,
|
||||
})
|
||||
}
|
||||
|
||||
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
|
||||
// vs GPT-2 byte-level encoding
|
||||
func detectSentencePiece(data json.RawMessage) bool {
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Sequence decoder with Replace step (SentencePiece style)
|
||||
var seq struct {
|
||||
Type string `json:"type"`
|
||||
Decoders []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
String string `json:"String"`
|
||||
} `json:"pattern"`
|
||||
} `json:"decoders"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &seq); err == nil {
|
||||
if seq.Type == "Sequence" {
|
||||
for _, dec := range seq.Decoders {
|
||||
// Look for Replace decoder that converts ▁ to space
|
||||
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for direct ByteLevel decoder (GPT-2 style)
|
||||
var simple struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &simple); err == nil {
|
||||
if simple.Type == "ByteLevel" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
|
||||
func initByteTokens(t *Tokenizer) {
|
||||
for i := range t.vocab.byteTokens {
|
||||
t.vocab.byteTokens[i] = -1
|
||||
}
|
||||
for b := 0; b < 256; b++ {
|
||||
token := fmt.Sprintf("<0x%02X>", b)
|
||||
if id, ok := t.vocab.Reverse[token]; ok {
|
||||
t.vocab.byteTokens[b] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
|
||||
func extractPretokenizer(data json.RawMessage) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to parse as a single Split pretokenizer
|
||||
var single struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
|
||||
return single.Pattern.Regex
|
||||
}
|
||||
|
||||
// Try to parse as Sequence of pretokenizers - use first Split pattern
|
||||
var seq struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
} `json:"pretokenizers"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
|
||||
for _, pt := range seq.Pretokenizers {
|
||||
if pt.Type == "Split" && pt.Pattern.Regex != "" {
|
||||
return pt.Pattern.Regex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "WordPiece",
|
||||
"vocab": {"[UNK]": 0, "hello": 1}
|
||||
},
|
||||
"added_tokens": []
|
||||
}`)
|
||||
|
||||
_, err := LoadFromBytes(data)
|
||||
if err == nil {
|
||||
t.Fatal("expected WordPiece load to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user